Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--  lib/sqlalchemy/__init__.py | 17
-rw-r--r--  lib/sqlalchemy/connectors/mxodbc.py | 39
-rw-r--r--  lib/sqlalchemy/connectors/pyodbc.py | 77
-rw-r--r--  lib/sqlalchemy/connectors/zxJDBC.py | 31
-rw-r--r--  lib/sqlalchemy/databases/__init__.py | 15
-rw-r--r--  lib/sqlalchemy/dialects/__init__.py | 22
-rw-r--r--  lib/sqlalchemy/dialects/firebird/__init__.py | 33
-rw-r--r--  lib/sqlalchemy/dialects/firebird/base.py | 569
-rw-r--r--  lib/sqlalchemy/dialects/firebird/fdb.py | 21
-rw-r--r--  lib/sqlalchemy/dialects/firebird/kinterbasdb.py | 76
-rw-r--r--  lib/sqlalchemy/dialects/mssql/__init__.py | 76
-rw-r--r--  lib/sqlalchemy/dialects/mssql/adodbapi.py | 26
-rw-r--r--  lib/sqlalchemy/dialects/mssql/base.py | 1238
-rw-r--r--  lib/sqlalchemy/dialects/mssql/information_schema.py | 210
-rw-r--r--  lib/sqlalchemy/dialects/mssql/mxodbc.py | 17
-rw-r--r--  lib/sqlalchemy/dialects/mssql/pymssql.py | 38
-rw-r--r--  lib/sqlalchemy/dialects/mssql/pyodbc.py | 102
-rw-r--r--  lib/sqlalchemy/dialects/mssql/zxjdbc.py | 16
-rw-r--r--  lib/sqlalchemy/dialects/mysql/__init__.py | 103
-rw-r--r--  lib/sqlalchemy/dialects/mysql/base.py | 1397
-rw-r--r--  lib/sqlalchemy/dialects/mysql/cymysql.py | 24
-rw-r--r--  lib/sqlalchemy/dialects/mysql/dml.py | 23
-rw-r--r--  lib/sqlalchemy/dialects/mysql/enumerated.py | 71
-rw-r--r--  lib/sqlalchemy/dialects/mysql/gaerdbms.py | 13
-rw-r--r--  lib/sqlalchemy/dialects/mysql/json.py | 11
-rw-r--r--  lib/sqlalchemy/dialects/mysql/mysqlconnector.py | 125
-rw-r--r--  lib/sqlalchemy/dialects/mysql/mysqldb.py | 119
-rw-r--r--  lib/sqlalchemy/dialects/mysql/oursql.py | 113
-rw-r--r--  lib/sqlalchemy/dialects/mysql/pymysql.py | 12
-rw-r--r--  lib/sqlalchemy/dialects/mysql/pyodbc.py | 14
-rw-r--r--  lib/sqlalchemy/dialects/mysql/reflection.py | 342
-rw-r--r--  lib/sqlalchemy/dialects/mysql/types.py | 146
-rw-r--r--  lib/sqlalchemy/dialects/mysql/zxjdbc.py | 29
-rw-r--r--  lib/sqlalchemy/dialects/oracle/__init__.py | 53
-rw-r--r--  lib/sqlalchemy/dialects/oracle/base.py | 1023
-rw-r--r--  lib/sqlalchemy/dialects/oracle/cx_oracle.py | 293
-rw-r--r--  lib/sqlalchemy/dialects/oracle/zxjdbc.py | 102
-rw-r--r--  lib/sqlalchemy/dialects/postgresql/__init__.py | 111
-rw-r--r--  lib/sqlalchemy/dialects/postgresql/array.py | 79
-rw-r--r--  lib/sqlalchemy/dialects/postgresql/base.py | 1437
-rw-r--r--  lib/sqlalchemy/dialects/postgresql/dml.py | 95
-rw-r--r--  lib/sqlalchemy/dialects/postgresql/ext.py | 44
-rw-r--r--  lib/sqlalchemy/dialects/postgresql/hstore.py | 129
-rw-r--r--  lib/sqlalchemy/dialects/postgresql/json.py | 66
-rw-r--r--  lib/sqlalchemy/dialects/postgresql/pg8000.py | 109
-rw-r--r--  lib/sqlalchemy/dialects/postgresql/psycopg2.py | 190
-rw-r--r--  lib/sqlalchemy/dialects/postgresql/psycopg2cffi.py | 10
-rw-r--r--  lib/sqlalchemy/dialects/postgresql/pygresql.py | 72
-rw-r--r--  lib/sqlalchemy/dialects/postgresql/pypostgresql.py | 28
-rw-r--r--  lib/sqlalchemy/dialects/postgresql/ranges.py | 57
-rw-r--r--  lib/sqlalchemy/dialects/postgresql/zxjdbc.py | 9
-rw-r--r--  lib/sqlalchemy/dialects/sqlite/__init__.py | 40
-rw-r--r--  lib/sqlalchemy/dialects/sqlite/base.py | 908
-rw-r--r--  lib/sqlalchemy/dialects/sqlite/json.py | 11
-rw-r--r--  lib/sqlalchemy/dialects/sqlite/pysqlcipher.py | 32
-rw-r--r--  lib/sqlalchemy/dialects/sqlite/pysqlite.py | 42
-rw-r--r--  lib/sqlalchemy/dialects/sybase/__init__.py | 62
-rw-r--r--  lib/sqlalchemy/dialects/sybase/base.py | 685
-rw-r--r--  lib/sqlalchemy/dialects/sybase/mxodbc.py | 1
-rw-r--r--  lib/sqlalchemy/dialects/sybase/pyodbc.py | 18
-rw-r--r--  lib/sqlalchemy/dialects/sybase/pysybase.py | 34
-rw-r--r--  lib/sqlalchemy/engine/__init__.py | 30
-rw-r--r--  lib/sqlalchemy/engine/base.py | 445
-rw-r--r--  lib/sqlalchemy/engine/default.py | 416
-rw-r--r--  lib/sqlalchemy/engine/interfaces.py | 51
-rw-r--r--  lib/sqlalchemy/engine/reflection.py | 436
-rw-r--r--  lib/sqlalchemy/engine/result.py | 384
-rw-r--r--  lib/sqlalchemy/engine/strategies.py | 120
-rw-r--r--  lib/sqlalchemy/engine/threadlocal.py | 50
-rw-r--r--  lib/sqlalchemy/engine/url.py | 115
-rw-r--r--  lib/sqlalchemy/engine/util.py | 16
-rw-r--r--  lib/sqlalchemy/event/api.py | 11
-rw-r--r--  lib/sqlalchemy/event/attr.py | 81
-rw-r--r--  lib/sqlalchemy/event/base.py | 45
-rw-r--r--  lib/sqlalchemy/event/legacy.py | 72
-rw-r--r--  lib/sqlalchemy/event/registry.py | 52
-rw-r--r--  lib/sqlalchemy/events.py | 74
-rw-r--r--  lib/sqlalchemy/exc.py | 121
-rw-r--r--  lib/sqlalchemy/ext/__init__.py | 1
-rw-r--r--  lib/sqlalchemy/ext/associationproxy.py | 288
-rw-r--r--  lib/sqlalchemy/ext/automap.py | 309
-rw-r--r--  lib/sqlalchemy/ext/baked.py | 135
-rw-r--r--  lib/sqlalchemy/ext/compiler.py | 30
-rw-r--r--  lib/sqlalchemy/ext/declarative/__init__.py | 35
-rw-r--r--  lib/sqlalchemy/ext/declarative/api.py | 117
-rw-r--r--  lib/sqlalchemy/ext/declarative/base.py | 326
-rw-r--r--  lib/sqlalchemy/ext/declarative/clsregistry.py | 125
-rw-r--r--  lib/sqlalchemy/ext/horizontal_shard.py | 61
-rw-r--r--  lib/sqlalchemy/ext/hybrid.py | 30
-rw-r--r--  lib/sqlalchemy/ext/indexable.py | 12
-rw-r--r--  lib/sqlalchemy/ext/instrumentation.py | 78
-rw-r--r--  lib/sqlalchemy/ext/mutable.py | 78
-rw-r--r--  lib/sqlalchemy/ext/orderinglist.py | 29
-rw-r--r--  lib/sqlalchemy/ext/serializer.py | 26
-rw-r--r--  lib/sqlalchemy/inspection.py | 14
-rw-r--r--  lib/sqlalchemy/interfaces.py | 102
-rw-r--r--  lib/sqlalchemy/log.py | 26
-rw-r--r--  lib/sqlalchemy/orm/__init__.py | 51
-rw-r--r--  lib/sqlalchemy/orm/attributes.py | 781
-rw-r--r--  lib/sqlalchemy/orm/base.py | 128
-rw-r--r--  lib/sqlalchemy/orm/collections.py | 294
-rw-r--r--  lib/sqlalchemy/orm/dependency.py | 1086
-rw-r--r--  lib/sqlalchemy/orm/deprecated_interfaces.py | 150
-rw-r--r--  lib/sqlalchemy/orm/descriptor_props.py | 207
-rw-r--r--  lib/sqlalchemy/orm/dynamic.py | 208
-rw-r--r--  lib/sqlalchemy/orm/evaluator.py | 101
-rw-r--r--  lib/sqlalchemy/orm/events.py | 175
-rw-r--r--  lib/sqlalchemy/orm/exc.py | 25
-rw-r--r--  lib/sqlalchemy/orm/identity.py | 18
-rw-r--r--  lib/sqlalchemy/orm/instrumentation.py | 105
-rw-r--r--  lib/sqlalchemy/orm/interfaces.py | 139
-rw-r--r--  lib/sqlalchemy/orm/loading.py | 381
-rw-r--r--  lib/sqlalchemy/orm/mapper.py | 1029
-rw-r--r--  lib/sqlalchemy/orm/path_registry.py | 70
-rw-r--r--  lib/sqlalchemy/orm/persistence.py | 1186
-rw-r--r--  lib/sqlalchemy/orm/properties.py | 137
-rw-r--r--  lib/sqlalchemy/orm/query.py | 978
-rw-r--r--  lib/sqlalchemy/orm/relationships.py | 1277
-rw-r--r--  lib/sqlalchemy/orm/scoping.py | 38
-rw-r--r--  lib/sqlalchemy/orm/session.py | 570
-rw-r--r--  lib/sqlalchemy/orm/state.py | 233
-rw-r--r--  lib/sqlalchemy/orm/strategies.py | 1144
-rw-r--r--  lib/sqlalchemy/orm/strategy_options.py | 310
-rw-r--r--  lib/sqlalchemy/orm/sync.py | 64
-rw-r--r--  lib/sqlalchemy/orm/unitofwork.py | 264
-rw-r--r--  lib/sqlalchemy/orm/util.py | 462
-rw-r--r--  lib/sqlalchemy/pool/__init__.py | 7
-rw-r--r--  lib/sqlalchemy/pool/base.py | 185
-rw-r--r--  lib/sqlalchemy/pool/dbapi_proxy.py | 12
-rw-r--r--  lib/sqlalchemy/pool/impl.py | 142
-rw-r--r--  lib/sqlalchemy/processors.py | 56
-rw-r--r--  lib/sqlalchemy/schema.py | 8
-rw-r--r--  lib/sqlalchemy/sql/__init__.py | 11
-rw-r--r--  lib/sqlalchemy/sql/annotation.py | 16
-rw-r--r--  lib/sqlalchemy/sql/base.py | 114
-rw-r--r--  lib/sqlalchemy/sql/compiler.py | 2030
-rw-r--r--  lib/sqlalchemy/sql/crud.py | 440
-rw-r--r--  lib/sqlalchemy/sql/ddl.py | 306
-rw-r--r--  lib/sqlalchemy/sql/default_comparator.py | 234
-rw-r--r--  lib/sqlalchemy/sql/dml.py | 194
-rw-r--r--  lib/sqlalchemy/sql/elements.py | 800
-rw-r--r--  lib/sqlalchemy/sql/expression.py | 205
-rw-r--r--  lib/sqlalchemy/sql/functions.py | 139
-rw-r--r--  lib/sqlalchemy/sql/naming.py | 47
-rw-r--r--  lib/sqlalchemy/sql/operators.py | 109
-rw-r--r--  lib/sqlalchemy/sql/schema.py | 1129
-rw-r--r--  lib/sqlalchemy/sql/selectable.py | 812
-rw-r--r--  lib/sqlalchemy/sql/sqltypes.py | 649
-rw-r--r--  lib/sqlalchemy/sql/type_api.py | 122
-rw-r--r--  lib/sqlalchemy/sql/util.py | 327
-rw-r--r--  lib/sqlalchemy/sql/visitors.py | 48
-rw-r--r--  lib/sqlalchemy/testing/__init__.py | 61
-rw-r--r--  lib/sqlalchemy/testing/assertions.py | 182
-rw-r--r--  lib/sqlalchemy/testing/assertsql.py | 148
-rw-r--r--  lib/sqlalchemy/testing/config.py | 6
-rw-r--r--  lib/sqlalchemy/testing/engines.py | 60
-rw-r--r--  lib/sqlalchemy/testing/entities.py | 23
-rw-r--r--  lib/sqlalchemy/testing/exclusions.py | 123
-rw-r--r--  lib/sqlalchemy/testing/fixtures.py | 85
-rw-r--r--  lib/sqlalchemy/testing/mock.py | 3
-rw-r--r--  lib/sqlalchemy/testing/pickleable.py | 36
-rw-r--r--  lib/sqlalchemy/testing/plugin/bootstrap.py | 7
-rw-r--r--  lib/sqlalchemy/testing/plugin/noseplugin.py | 23
-rw-r--r--  lib/sqlalchemy/testing/plugin/plugin_base.py | 359
-rw-r--r--  lib/sqlalchemy/testing/plugin/pytestplugin.py | 104
-rw-r--r--  lib/sqlalchemy/testing/profiling.py | 71
-rw-r--r--  lib/sqlalchemy/testing/provision.py | 86
-rw-r--r--  lib/sqlalchemy/testing/replay_fixture.py | 81
-rw-r--r--  lib/sqlalchemy/testing/requirements.py | 78
-rw-r--r--  lib/sqlalchemy/testing/runner.py | 2
-rw-r--r--  lib/sqlalchemy/testing/schema.py | 68
-rw-r--r--  lib/sqlalchemy/testing/suite/__init__.py | 1
-rw-r--r--  lib/sqlalchemy/testing/suite/test_cte.py | 132
-rw-r--r--  lib/sqlalchemy/testing/suite/test_ddl.py | 50
-rw-r--r--  lib/sqlalchemy/testing/suite/test_dialect.py | 54
-rw-r--r--  lib/sqlalchemy/testing/suite/test_insert.py | 221
-rw-r--r--  lib/sqlalchemy/testing/suite/test_reflection.py | 762
-rw-r--r--  lib/sqlalchemy/testing/suite/test_results.py | 242
-rw-r--r--  lib/sqlalchemy/testing/suite/test_select.py | 483
-rw-r--r--  lib/sqlalchemy/testing/suite/test_sequence.py | 136
-rw-r--r--  lib/sqlalchemy/testing/suite/test_types.py | 702
-rw-r--r--  lib/sqlalchemy/testing/suite/test_update_delete.py | 37
-rw-r--r--  lib/sqlalchemy/testing/util.py | 66
-rw-r--r--  lib/sqlalchemy/testing/warnings.py | 21
-rw-r--r--  lib/sqlalchemy/types.py | 62
-rw-r--r--  lib/sqlalchemy/util/__init__.py | 170
-rw-r--r--  lib/sqlalchemy/util/_collections.py | 122
-rw-r--r--  lib/sqlalchemy/util/compat.py | 100
-rw-r--r--  lib/sqlalchemy/util/deprecations.py | 35
-rw-r--r--  lib/sqlalchemy/util/langhelpers.py | 423
-rw-r--r--  lib/sqlalchemy/util/queue.py | 2
-rw-r--r--  lib/sqlalchemy/util/topological.py | 12
192 files changed, 24125 insertions(+), 16313 deletions(-)
diff --git a/lib/sqlalchemy/__init__.py b/lib/sqlalchemy/__init__.py
index 6162ead5c..171f23028 100644
--- a/lib/sqlalchemy/__init__.py
+++ b/lib/sqlalchemy/__init__.py
@@ -56,7 +56,7 @@ from .sql import (
union_all,
update,
within_group,
- )
+)
from .types import (
ARRAY,
@@ -102,7 +102,7 @@ from .types import (
UnicodeText,
VARBINARY,
VARCHAR,
- )
+)
from .schema import (
@@ -123,14 +123,14 @@ from .schema import (
ThreadLocalMetaData,
UniqueConstraint,
DDL,
- BLANK_SCHEMA
+ BLANK_SCHEMA,
)
from .inspection import inspect
from .engine import create_engine, engine_from_config
-__version__ = '1.3.0b2'
+__version__ = "1.3.0b2"
def __go(lcls):
@@ -141,8 +141,13 @@ def __go(lcls):
import inspect as _inspect
- __all__ = sorted(name for name, obj in lcls.items()
- if not (name.startswith('_') or _inspect.ismodule(obj)))
+ __all__ = sorted(
+ name
+ for name, obj in lcls.items()
+ if not (name.startswith("_") or _inspect.ismodule(obj))
+ )
_sa_util.dependencies.resolve_all("sqlalchemy")
+
+
__go(locals())
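
The reformatted __all__ construction above filters the package namespace down to public, non-module names. A minimal standalone sketch of the same filter, using a made-up namespace dict rather than SQLAlchemy's actual locals():

    import inspect

    # Stand-in for the lcls dict passed to __go(); names are illustrative.
    namespace = {"create_engine": object(), "_private": object(), "util": inspect}
    public = sorted(
        name
        for name, obj in namespace.items()
        if not (name.startswith("_") or inspect.ismodule(obj))
    )
    print(public)  # ['create_engine']
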
diff --git a/lib/sqlalchemy/connectors/mxodbc.py b/lib/sqlalchemy/connectors/mxodbc.py
index 65be4c7d5..209877e4a 100644
--- a/lib/sqlalchemy/connectors/mxodbc.py
+++ b/lib/sqlalchemy/connectors/mxodbc.py
@@ -27,7 +27,7 @@ from . import Connector
class MxODBCConnector(Connector):
- driver = 'mxodbc'
+ driver = "mxodbc"
supports_sane_multi_rowcount = False
supports_unicode_statements = True
@@ -41,12 +41,12 @@ class MxODBCConnector(Connector):
# attribute of the same name, so this is normally only called once.
cls._load_mx_exceptions()
platform = sys.platform
- if platform == 'win32':
+ if platform == "win32":
from mx.ODBC import Windows as Module
# this can be the string "linux2", and possibly others
- elif 'linux' in platform:
+ elif "linux" in platform:
from mx.ODBC import unixODBC as Module
- elif platform == 'darwin':
+ elif platform == "darwin":
from mx.ODBC import iODBC as Module
else:
raise ImportError("Unrecognized platform for mxODBC import")
@@ -68,6 +68,7 @@ class MxODBCConnector(Connector):
conn.datetimeformat = self.dbapi.PYDATETIME_DATETIMEFORMAT
conn.decimalformat = self.dbapi.DECIMAL_DECIMALFORMAT
conn.errorhandler = self._error_handler()
+
return connect
def _error_handler(self):
@@ -79,11 +80,12 @@ class MxODBCConnector(Connector):
def error_handler(connection, cursor, errorclass, errorvalue):
if issubclass(errorclass, MxOdbcWarning):
errorclass.__bases__ = (Warning,)
- warnings.warn(message=str(errorvalue),
- category=errorclass,
- stacklevel=2)
+ warnings.warn(
+ message=str(errorvalue), category=errorclass, stacklevel=2
+ )
else:
raise errorclass(errorvalue)
+
return error_handler
def create_connect_args(self, url):
@@ -101,11 +103,11 @@ class MxODBCConnector(Connector):
not be populated.
"""
- opts = url.translate_connect_args(username='user')
+ opts = url.translate_connect_args(username="user")
opts.update(url.query)
- args = opts.pop('host')
- opts.pop('port', None)
- opts.pop('database', None)
+ args = opts.pop("host")
+ opts.pop("port", None)
+ opts.pop("database", None)
return (args,), opts
def is_disconnect(self, e, connection, cursor):
@@ -114,7 +116,7 @@ class MxODBCConnector(Connector):
if isinstance(e, self.dbapi.ProgrammingError):
return "connection already closed" in str(e)
elif isinstance(e, self.dbapi.Error):
- return '[08S01]' in str(e)
+ return "[08S01]" in str(e)
else:
return False
@@ -123,7 +125,7 @@ class MxODBCConnector(Connector):
# of what we're doing here
dbapi_con = connection.connection
version = []
- r = re.compile(r'[.\-]')
+ r = re.compile(r"[.\-]")
# 18 == pyodbc.SQL_DBMS_VER
for n in r.split(dbapi_con.getinfo(18)[1]):
try:
@@ -134,8 +136,9 @@ class MxODBCConnector(Connector):
def _get_direct(self, context):
if context:
- native_odbc_execute = context.execution_options.\
- get('native_odbc_execute', 'auto')
+ native_odbc_execute = context.execution_options.get(
+ "native_odbc_execute", "auto"
+ )
# default to direct=True in all cases, is more generally
# compatible especially with SQL Server
return False if native_odbc_execute is True else True
@@ -144,8 +147,8 @@ class MxODBCConnector(Connector):
def do_executemany(self, cursor, statement, parameters, context=None):
cursor.executemany(
- statement, parameters, direct=self._get_direct(context))
+ statement, parameters, direct=self._get_direct(context)
+ )
def do_execute(self, cursor, statement, parameters, context=None):
- cursor.execute(statement, parameters,
- direct=self._get_direct(context))
+ cursor.execute(statement, parameters, direct=self._get_direct(context))
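
The _get_direct() rework above preserves the existing behavior: mxODBC statements default to direct execution, and only an explicit native_odbc_execute=True execution option disables it. A standalone sketch of that decision, where options is a hypothetical stand-in for context.execution_options:

    def get_direct(options):
        native_odbc_execute = options.get("native_odbc_execute", "auto")
        # default to direct=True; only an explicit True turns it off
        return False if native_odbc_execute is True else True

    assert get_direct({}) is True
    assert get_direct({"native_odbc_execute": True}) is False
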
diff --git a/lib/sqlalchemy/connectors/pyodbc.py b/lib/sqlalchemy/connectors/pyodbc.py
index 41ba89de6..8f5eea89b 100644
--- a/lib/sqlalchemy/connectors/pyodbc.py
+++ b/lib/sqlalchemy/connectors/pyodbc.py
@@ -13,7 +13,7 @@ import re
class PyODBCConnector(Connector):
- driver = 'pyodbc'
+ driver = "pyodbc"
supports_sane_rowcount_returning = False
supports_sane_multi_rowcount = False
@@ -22,7 +22,7 @@ class PyODBCConnector(Connector):
supports_unicode_binds = True
supports_native_decimal = True
- default_paramstyle = 'named'
+ default_paramstyle = "named"
# for non-DSN connections, this *may* be used to
# hold the desired driver name
@@ -35,10 +35,10 @@ class PyODBCConnector(Connector):
@classmethod
def dbapi(cls):
- return __import__('pyodbc')
+ return __import__("pyodbc")
def create_connect_args(self, url):
- opts = url.translate_connect_args(username='user')
+ opts = url.translate_connect_args(username="user")
opts.update(url.query)
keys = opts
@@ -46,52 +46,55 @@ class PyODBCConnector(Connector):
query = url.query
connect_args = {}
- for param in ('ansi', 'unicode_results', 'autocommit'):
+ for param in ("ansi", "unicode_results", "autocommit"):
if param in keys:
connect_args[param] = util.asbool(keys.pop(param))
- if 'odbc_connect' in keys:
- connectors = [util.unquote_plus(keys.pop('odbc_connect'))]
+ if "odbc_connect" in keys:
+ connectors = [util.unquote_plus(keys.pop("odbc_connect"))]
else:
+
def check_quote(token):
if ";" in str(token):
token = "'%s'" % token
return token
- keys = dict(
- (k, check_quote(v)) for k, v in keys.items()
- )
+ keys = dict((k, check_quote(v)) for k, v in keys.items())
- dsn_connection = 'dsn' in keys or \
- ('host' in keys and 'database' not in keys)
+ dsn_connection = "dsn" in keys or (
+ "host" in keys and "database" not in keys
+ )
if dsn_connection:
- connectors = ['dsn=%s' % (keys.pop('host', '') or
- keys.pop('dsn', ''))]
+ connectors = [
+ "dsn=%s" % (keys.pop("host", "") or keys.pop("dsn", ""))
+ ]
else:
- port = ''
- if 'port' in keys and 'port' not in query:
- port = ',%d' % int(keys.pop('port'))
+ port = ""
+ if "port" in keys and "port" not in query:
+ port = ",%d" % int(keys.pop("port"))
connectors = []
- driver = keys.pop('driver', self.pyodbc_driver_name)
+ driver = keys.pop("driver", self.pyodbc_driver_name)
if driver is None:
util.warn(
"No driver name specified; "
"this is expected by PyODBC when using "
- "DSN-less connections")
+ "DSN-less connections"
+ )
else:
connectors.append("DRIVER={%s}" % driver)
connectors.extend(
[
- 'Server=%s%s' % (keys.pop('host', ''), port),
- 'Database=%s' % keys.pop('database', '')
- ])
+ "Server=%s%s" % (keys.pop("host", ""), port),
+ "Database=%s" % keys.pop("database", ""),
+ ]
+ )
user = keys.pop("user", None)
if user:
connectors.append("UID=%s" % user)
- connectors.append("PWD=%s" % keys.pop('password', ''))
+ connectors.append("PWD=%s" % keys.pop("password", ""))
else:
connectors.append("Trusted_Connection=Yes")
@@ -99,18 +102,20 @@ class PyODBCConnector(Connector):
# convert textual data from your database encoding to your
# client encoding. This should obviously be set to 'No' if
# you query a cp1253 encoded database from a latin1 client...
- if 'odbc_autotranslate' in keys:
- connectors.append("AutoTranslate=%s" %
- keys.pop("odbc_autotranslate"))
+ if "odbc_autotranslate" in keys:
+ connectors.append(
+ "AutoTranslate=%s" % keys.pop("odbc_autotranslate")
+ )
- connectors.extend(['%s=%s' % (k, v) for k, v in keys.items()])
+ connectors.extend(["%s=%s" % (k, v) for k, v in keys.items()])
return [[";".join(connectors)], connect_args]
def is_disconnect(self, e, connection, cursor):
if isinstance(e, self.dbapi.ProgrammingError):
- return "The cursor's connection has been closed." in str(e) or \
- 'Attempt to use a closed connection.' in str(e)
+ return "The cursor's connection has been closed." in str(
+ e
+ ) or "Attempt to use a closed connection." in str(e)
else:
return False
@@ -123,10 +128,7 @@ class PyODBCConnector(Connector):
return self._parse_dbapi_version(self.dbapi.version)
def _parse_dbapi_version(self, vers):
- m = re.match(
- r'(?:py.*-)?([\d\.]+)(?:-(\w+))?',
- vers
- )
+ m = re.match(r"(?:py.*-)?([\d\.]+)(?:-(\w+))?", vers)
if not m:
return ()
vers = tuple([int(x) for x in m.group(1).split(".")])
@@ -140,7 +142,7 @@ class PyODBCConnector(Connector):
# queries.
dbapi_con = connection.connection
version = []
- r = re.compile(r'[.\-]')
+ r = re.compile(r"[.\-]")
for n in r.split(dbapi_con.getinfo(self.dbapi.SQL_DBMS_VER)):
try:
version.append(int(n))
@@ -153,12 +155,11 @@ class PyODBCConnector(Connector):
# adjust for ConnectionFairy being present
# allows attribute set e.g. "connection.autocommit = True"
# to work properly
- if hasattr(connection, 'connection'):
+ if hasattr(connection, "connection"):
connection = connection.connection
- if level == 'AUTOCOMMIT':
+ if level == "AUTOCOMMIT":
connection.autocommit = True
else:
connection.autocommit = False
- super(PyODBCConnector, self).set_isolation_level(connection,
- level)
+ super(PyODBCConnector, self).set_isolation_level(connection, level)
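
The reflowed create_connect_args() above still assembles a semicolon-joined ODBC string for DSN-less connections. A rough sketch of the string shape it builds, with hypothetical host, database, credential, and driver values (the real method pulls these from the SQLAlchemy URL):

    host, port, database = "db.example.com", 1433, "mydb"
    user, password = "scott", "tiger"
    connectors = ["DRIVER={ODBC Driver 17 for SQL Server}"]  # assumed driver name
    connectors.append("Server=%s,%d" % (host, port))
    connectors.append("Database=%s" % database)
    connectors.append("UID=%s" % user)
    connectors.append("PWD=%s" % password)
    print(";".join(connectors))
    # DRIVER={ODBC Driver 17 for SQL Server};Server=db.example.com,1433;Database=mydb;UID=scott;PWD=tiger
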
diff --git a/lib/sqlalchemy/connectors/zxJDBC.py b/lib/sqlalchemy/connectors/zxJDBC.py
index 71decd9ab..003ecbed1 100644
--- a/lib/sqlalchemy/connectors/zxJDBC.py
+++ b/lib/sqlalchemy/connectors/zxJDBC.py
@@ -10,15 +10,15 @@ from . import Connector
class ZxJDBCConnector(Connector):
- driver = 'zxjdbc'
+ driver = "zxjdbc"
supports_sane_rowcount = False
supports_sane_multi_rowcount = False
supports_unicode_binds = True
- supports_unicode_statements = sys.version > '2.5.0+'
+ supports_unicode_statements = sys.version > "2.5.0+"
description_encoding = None
- default_paramstyle = 'qmark'
+ default_paramstyle = "qmark"
jdbc_db_name = None
jdbc_driver_name = None
@@ -26,6 +26,7 @@ class ZxJDBCConnector(Connector):
@classmethod
def dbapi(cls):
from com.ziclix.python.sql import zxJDBC
+
return zxJDBC
def _driver_kwargs(self):
@@ -34,25 +35,31 @@ class ZxJDBCConnector(Connector):
def _create_jdbc_url(self, url):
"""Create a JDBC url from a :class:`~sqlalchemy.engine.url.URL`"""
- return 'jdbc:%s://%s%s/%s' % (self.jdbc_db_name, url.host,
- url.port is not None
- and ':%s' % url.port or '',
- url.database)
+ return "jdbc:%s://%s%s/%s" % (
+ self.jdbc_db_name,
+ url.host,
+ url.port is not None and ":%s" % url.port or "",
+ url.database,
+ )
def create_connect_args(self, url):
opts = self._driver_kwargs()
opts.update(url.query)
return [
- [self._create_jdbc_url(url),
- url.username, url.password,
- self.jdbc_driver_name],
- opts]
+ [
+ self._create_jdbc_url(url),
+ url.username,
+ url.password,
+ self.jdbc_driver_name,
+ ],
+ opts,
+ ]
def is_disconnect(self, e, connection, cursor):
if not isinstance(e, self.dbapi.ProgrammingError):
return False
e = str(e)
- return 'connection is closed' in e or 'cursor is closed' in e
+ return "connection is closed" in e or "cursor is closed" in e
def _get_server_version_info(self, connection):
# use connection.connection.dbversion, and parse appropriately
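
_create_jdbc_url() above folds an optional port into the JDBC URL via an and/or chain. A standalone sketch showing the two shapes it can produce; the sample values are made up:

    def create_jdbc_url(jdbc_db_name, host, port, database):
        # the port segment collapses to "" when port is None
        return "jdbc:%s://%s%s/%s" % (
            jdbc_db_name,
            host,
            port is not None and ":%s" % port or "",
            database,
        )

    print(create_jdbc_url("mysql", "localhost", 3306, "test"))
    # jdbc:mysql://localhost:3306/test
    print(create_jdbc_url("mysql", "localhost", None, "test"))
    # jdbc:mysql://localhost/test
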
diff --git a/lib/sqlalchemy/databases/__init__.py b/lib/sqlalchemy/databases/__init__.py
index 2cb252737..d2d56a7ae 100644
--- a/lib/sqlalchemy/databases/__init__.py
+++ b/lib/sqlalchemy/databases/__init__.py
@@ -11,6 +11,7 @@ compatibility with pre 0.6 versions.
"""
from ..dialects.sqlite import base as sqlite
from ..dialects.postgresql import base as postgresql
+
postgres = postgresql
from ..dialects.mysql import base as mysql
from ..dialects.oracle import base as oracle
@@ -20,11 +21,11 @@ from ..dialects.sybase import base as sybase
__all__ = (
- 'firebird',
- 'mssql',
- 'mysql',
- 'postgresql',
- 'sqlite',
- 'oracle',
- 'sybase',
+ "firebird",
+ "mssql",
+ "mysql",
+ "postgresql",
+ "sqlite",
+ "oracle",
+ "sybase",
)
diff --git a/lib/sqlalchemy/dialects/__init__.py b/lib/sqlalchemy/dialects/__init__.py
index 963babcb8..65f30bb76 100644
--- a/lib/sqlalchemy/dialects/__init__.py
+++ b/lib/sqlalchemy/dialects/__init__.py
@@ -6,18 +6,19 @@
# the MIT License: http://www.opensource.org/licenses/mit-license.php
__all__ = (
- 'firebird',
- 'mssql',
- 'mysql',
- 'oracle',
- 'postgresql',
- 'sqlite',
- 'sybase',
+ "firebird",
+ "mssql",
+ "mysql",
+ "oracle",
+ "postgresql",
+ "sqlite",
+ "sybase",
)
from .. import util
-_translates = {'postgres': 'postgresql'}
+_translates = {"postgres": "postgresql"}
+
def _auto_fn(name):
"""default dialect importer.
@@ -40,7 +41,7 @@ def _auto_fn(name):
)
dialect = translated
try:
- module = __import__('sqlalchemy.dialects.%s' % (dialect, )).dialects
+ module = __import__("sqlalchemy.dialects.%s" % (dialect,)).dialects
except ImportError:
return None
@@ -51,6 +52,7 @@ def _auto_fn(name):
else:
return None
+
registry = util.PluginLoader("sqlalchemy.dialects", auto_fn=_auto_fn)
-plugins = util.PluginLoader("sqlalchemy.plugins")
\ No newline at end of file
+plugins = util.PluginLoader("sqlalchemy.plugins")
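
_auto_fn above is the registry's fallback importer: it translates the legacy "postgres" name to "postgresql" and then imports the dialect package. A minimal sketch of that translate-then-import step, with the driver resolution and error details elided:

    _translates = {"postgres": "postgresql"}

    def auto_import(name):
        # translate legacy names, then import sqlalchemy.dialects.<dialect>
        dialect = _translates.get(name, name)
        try:
            return __import__("sqlalchemy.dialects.%s" % (dialect,)).dialects
        except ImportError:
            return None
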
diff --git a/lib/sqlalchemy/dialects/firebird/__init__.py b/lib/sqlalchemy/dialects/firebird/__init__.py
index c83db453b..510d62337 100644
--- a/lib/sqlalchemy/dialects/firebird/__init__.py
+++ b/lib/sqlalchemy/dialects/firebird/__init__.py
@@ -7,14 +7,35 @@
from . import base, kinterbasdb, fdb # noqa
-from sqlalchemy.dialects.firebird.base import \
- SMALLINT, BIGINT, FLOAT, DATE, TIME, \
- TEXT, NUMERIC, TIMESTAMP, VARCHAR, CHAR, BLOB
+from sqlalchemy.dialects.firebird.base import (
+ SMALLINT,
+ BIGINT,
+ FLOAT,
+ DATE,
+ TIME,
+ TEXT,
+ NUMERIC,
+ TIMESTAMP,
+ VARCHAR,
+ CHAR,
+ BLOB,
+)
base.dialect = dialect = fdb.dialect
__all__ = (
- 'SMALLINT', 'BIGINT', 'FLOAT', 'FLOAT', 'DATE', 'TIME',
- 'TEXT', 'NUMERIC', 'FLOAT', 'TIMESTAMP', 'VARCHAR', 'CHAR', 'BLOB',
- 'dialect'
+ "SMALLINT",
+ "BIGINT",
+ "FLOAT",
+ "FLOAT",
+ "DATE",
+ "TIME",
+ "TEXT",
+ "NUMERIC",
+ "FLOAT",
+ "TIMESTAMP",
+ "VARCHAR",
+ "CHAR",
+ "BLOB",
+ "dialect",
)
diff --git a/lib/sqlalchemy/dialects/firebird/base.py b/lib/sqlalchemy/dialects/firebird/base.py
index 7b470c189..1e9c778f3 100644
--- a/lib/sqlalchemy/dialects/firebird/base.py
+++ b/lib/sqlalchemy/dialects/firebird/base.py
@@ -79,48 +79,254 @@ from sqlalchemy.engine import base, default, reflection
from sqlalchemy.sql import compiler
from sqlalchemy.sql.elements import quoted_name
-from sqlalchemy.types import (BIGINT, BLOB, DATE, FLOAT, INTEGER, NUMERIC,
- SMALLINT, TEXT, TIME, TIMESTAMP, Integer)
-
-
-RESERVED_WORDS = set([
- "active", "add", "admin", "after", "all", "alter", "and", "any", "as",
- "asc", "ascending", "at", "auto", "avg", "before", "begin", "between",
- "bigint", "bit_length", "blob", "both", "by", "case", "cast", "char",
- "character", "character_length", "char_length", "check", "close",
- "collate", "column", "commit", "committed", "computed", "conditional",
- "connect", "constraint", "containing", "count", "create", "cross",
- "cstring", "current", "current_connection", "current_date",
- "current_role", "current_time", "current_timestamp",
- "current_transaction", "current_user", "cursor", "database", "date",
- "day", "dec", "decimal", "declare", "default", "delete", "desc",
- "descending", "disconnect", "distinct", "do", "domain", "double",
- "drop", "else", "end", "entry_point", "escape", "exception",
- "execute", "exists", "exit", "external", "extract", "fetch", "file",
- "filter", "float", "for", "foreign", "from", "full", "function",
- "gdscode", "generator", "gen_id", "global", "grant", "group",
- "having", "hour", "if", "in", "inactive", "index", "inner",
- "input_type", "insensitive", "insert", "int", "integer", "into", "is",
- "isolation", "join", "key", "leading", "left", "length", "level",
- "like", "long", "lower", "manual", "max", "maximum_segment", "merge",
- "min", "minute", "module_name", "month", "names", "national",
- "natural", "nchar", "no", "not", "null", "numeric", "octet_length",
- "of", "on", "only", "open", "option", "or", "order", "outer",
- "output_type", "overflow", "page", "pages", "page_size", "parameter",
- "password", "plan", "position", "post_event", "precision", "primary",
- "privileges", "procedure", "protected", "rdb$db_key", "read", "real",
- "record_version", "recreate", "recursive", "references", "release",
- "reserv", "reserving", "retain", "returning_values", "returns",
- "revoke", "right", "rollback", "rows", "row_count", "savepoint",
- "schema", "second", "segment", "select", "sensitive", "set", "shadow",
- "shared", "singular", "size", "smallint", "snapshot", "some", "sort",
- "sqlcode", "stability", "start", "starting", "starts", "statistics",
- "sub_type", "sum", "suspend", "table", "then", "time", "timestamp",
- "to", "trailing", "transaction", "trigger", "trim", "uncommitted",
- "union", "unique", "update", "upper", "user", "using", "value",
- "values", "varchar", "variable", "varying", "view", "wait", "when",
- "where", "while", "with", "work", "write", "year",
-])
+from sqlalchemy.types import (
+ BIGINT,
+ BLOB,
+ DATE,
+ FLOAT,
+ INTEGER,
+ NUMERIC,
+ SMALLINT,
+ TEXT,
+ TIME,
+ TIMESTAMP,
+ Integer,
+)
+
+
+RESERVED_WORDS = set(
+ [
+ "active",
+ "add",
+ "admin",
+ "after",
+ "all",
+ "alter",
+ "and",
+ "any",
+ "as",
+ "asc",
+ "ascending",
+ "at",
+ "auto",
+ "avg",
+ "before",
+ "begin",
+ "between",
+ "bigint",
+ "bit_length",
+ "blob",
+ "both",
+ "by",
+ "case",
+ "cast",
+ "char",
+ "character",
+ "character_length",
+ "char_length",
+ "check",
+ "close",
+ "collate",
+ "column",
+ "commit",
+ "committed",
+ "computed",
+ "conditional",
+ "connect",
+ "constraint",
+ "containing",
+ "count",
+ "create",
+ "cross",
+ "cstring",
+ "current",
+ "current_connection",
+ "current_date",
+ "current_role",
+ "current_time",
+ "current_timestamp",
+ "current_transaction",
+ "current_user",
+ "cursor",
+ "database",
+ "date",
+ "day",
+ "dec",
+ "decimal",
+ "declare",
+ "default",
+ "delete",
+ "desc",
+ "descending",
+ "disconnect",
+ "distinct",
+ "do",
+ "domain",
+ "double",
+ "drop",
+ "else",
+ "end",
+ "entry_point",
+ "escape",
+ "exception",
+ "execute",
+ "exists",
+ "exit",
+ "external",
+ "extract",
+ "fetch",
+ "file",
+ "filter",
+ "float",
+ "for",
+ "foreign",
+ "from",
+ "full",
+ "function",
+ "gdscode",
+ "generator",
+ "gen_id",
+ "global",
+ "grant",
+ "group",
+ "having",
+ "hour",
+ "if",
+ "in",
+ "inactive",
+ "index",
+ "inner",
+ "input_type",
+ "insensitive",
+ "insert",
+ "int",
+ "integer",
+ "into",
+ "is",
+ "isolation",
+ "join",
+ "key",
+ "leading",
+ "left",
+ "length",
+ "level",
+ "like",
+ "long",
+ "lower",
+ "manual",
+ "max",
+ "maximum_segment",
+ "merge",
+ "min",
+ "minute",
+ "module_name",
+ "month",
+ "names",
+ "national",
+ "natural",
+ "nchar",
+ "no",
+ "not",
+ "null",
+ "numeric",
+ "octet_length",
+ "of",
+ "on",
+ "only",
+ "open",
+ "option",
+ "or",
+ "order",
+ "outer",
+ "output_type",
+ "overflow",
+ "page",
+ "pages",
+ "page_size",
+ "parameter",
+ "password",
+ "plan",
+ "position",
+ "post_event",
+ "precision",
+ "primary",
+ "privileges",
+ "procedure",
+ "protected",
+ "rdb$db_key",
+ "read",
+ "real",
+ "record_version",
+ "recreate",
+ "recursive",
+ "references",
+ "release",
+ "reserv",
+ "reserving",
+ "retain",
+ "returning_values",
+ "returns",
+ "revoke",
+ "right",
+ "rollback",
+ "rows",
+ "row_count",
+ "savepoint",
+ "schema",
+ "second",
+ "segment",
+ "select",
+ "sensitive",
+ "set",
+ "shadow",
+ "shared",
+ "singular",
+ "size",
+ "smallint",
+ "snapshot",
+ "some",
+ "sort",
+ "sqlcode",
+ "stability",
+ "start",
+ "starting",
+ "starts",
+ "statistics",
+ "sub_type",
+ "sum",
+ "suspend",
+ "table",
+ "then",
+ "time",
+ "timestamp",
+ "to",
+ "trailing",
+ "transaction",
+ "trigger",
+ "trim",
+ "uncommitted",
+ "union",
+ "unique",
+ "update",
+ "upper",
+ "user",
+ "using",
+ "value",
+ "values",
+ "varchar",
+ "variable",
+ "varying",
+ "view",
+ "wait",
+ "when",
+ "where",
+ "while",
+ "with",
+ "work",
+ "write",
+ "year",
+ ]
+)
class _StringType(sqltypes.String):
@@ -133,7 +339,8 @@ class _StringType(sqltypes.String):
class VARCHAR(_StringType, sqltypes.VARCHAR):
"""Firebird VARCHAR type"""
- __visit_name__ = 'VARCHAR'
+
+ __visit_name__ = "VARCHAR"
def __init__(self, length=None, **kwargs):
super(VARCHAR, self).__init__(length=length, **kwargs)
@@ -141,7 +348,8 @@ class VARCHAR(_StringType, sqltypes.VARCHAR):
class CHAR(_StringType, sqltypes.CHAR):
"""Firebird CHAR type"""
- __visit_name__ = 'CHAR'
+
+ __visit_name__ = "CHAR"
def __init__(self, length=None, **kwargs):
super(CHAR, self).__init__(length=length, **kwargs)
@@ -154,32 +362,33 @@ class _FBDateTime(sqltypes.DateTime):
return datetime.datetime(value.year, value.month, value.day)
else:
return value
+
return process
-colspecs = {
- sqltypes.DateTime: _FBDateTime
-}
+
+colspecs = {sqltypes.DateTime: _FBDateTime}
ischema_names = {
- 'SHORT': SMALLINT,
- 'LONG': INTEGER,
- 'QUAD': FLOAT,
- 'FLOAT': FLOAT,
- 'DATE': DATE,
- 'TIME': TIME,
- 'TEXT': TEXT,
- 'INT64': BIGINT,
- 'DOUBLE': FLOAT,
- 'TIMESTAMP': TIMESTAMP,
- 'VARYING': VARCHAR,
- 'CSTRING': CHAR,
- 'BLOB': BLOB,
+ "SHORT": SMALLINT,
+ "LONG": INTEGER,
+ "QUAD": FLOAT,
+ "FLOAT": FLOAT,
+ "DATE": DATE,
+ "TIME": TIME,
+ "TEXT": TEXT,
+ "INT64": BIGINT,
+ "DOUBLE": FLOAT,
+ "TIMESTAMP": TIMESTAMP,
+ "VARYING": VARCHAR,
+ "CSTRING": CHAR,
+ "BLOB": BLOB,
}
# TODO: date conversion types (should be implemented as _FBDateTime,
# _FBDate, etc. as bind/result functionality is required)
+
class FBTypeCompiler(compiler.GenericTypeCompiler):
def visit_boolean(self, type_, **kw):
return self.visit_SMALLINT(type_, **kw)
@@ -194,11 +403,11 @@ class FBTypeCompiler(compiler.GenericTypeCompiler):
return "BLOB SUB_TYPE 0"
def _extend_string(self, type_, basic):
- charset = getattr(type_, 'charset', None)
+ charset = getattr(type_, "charset", None)
if charset is None:
return basic
else:
- return '%s CHARACTER SET %s' % (basic, charset)
+ return "%s CHARACTER SET %s" % (basic, charset)
def visit_CHAR(self, type_, **kw):
basic = super(FBTypeCompiler, self).visit_CHAR(type_, **kw)
@@ -207,8 +416,8 @@ class FBTypeCompiler(compiler.GenericTypeCompiler):
def visit_VARCHAR(self, type_, **kw):
if not type_.length:
raise exc.CompileError(
- "VARCHAR requires a length on dialect %s" %
- self.dialect.name)
+ "VARCHAR requires a length on dialect %s" % self.dialect.name
+ )
basic = super(FBTypeCompiler, self).visit_VARCHAR(type_, **kw)
return self._extend_string(type_, basic)
@@ -228,36 +437,42 @@ class FBCompiler(sql.compiler.SQLCompiler):
return "CURRENT_TIMESTAMP"
def visit_startswith_op_binary(self, binary, operator, **kw):
- return '%s STARTING WITH %s' % (
+ return "%s STARTING WITH %s" % (
binary.left._compiler_dispatch(self, **kw),
- binary.right._compiler_dispatch(self, **kw))
+ binary.right._compiler_dispatch(self, **kw),
+ )
def visit_notstartswith_op_binary(self, binary, operator, **kw):
- return '%s NOT STARTING WITH %s' % (
+ return "%s NOT STARTING WITH %s" % (
binary.left._compiler_dispatch(self, **kw),
- binary.right._compiler_dispatch(self, **kw))
+ binary.right._compiler_dispatch(self, **kw),
+ )
def visit_mod_binary(self, binary, operator, **kw):
return "mod(%s, %s)" % (
self.process(binary.left, **kw),
- self.process(binary.right, **kw))
+ self.process(binary.right, **kw),
+ )
def visit_alias(self, alias, asfrom=False, **kwargs):
if self.dialect._version_two:
- return super(FBCompiler, self).\
- visit_alias(alias, asfrom=asfrom, **kwargs)
+ return super(FBCompiler, self).visit_alias(
+ alias, asfrom=asfrom, **kwargs
+ )
else:
# Override to not use the AS keyword which FB 1.5 does not like
if asfrom:
- alias_name = isinstance(alias.name,
- expression._truncated_label) and \
- self._truncated_identifier("alias",
- alias.name) or alias.name
-
- return self.process(
- alias.original, asfrom=asfrom, **kwargs) + \
- " " + \
- self.preparer.format_alias(alias, alias_name)
+ alias_name = (
+ isinstance(alias.name, expression._truncated_label)
+ and self._truncated_identifier("alias", alias.name)
+ or alias.name
+ )
+
+ return (
+ self.process(alias.original, asfrom=asfrom, **kwargs)
+ + " "
+ + self.preparer.format_alias(alias, alias_name)
+ )
else:
return self.process(alias.original, **kwargs)
@@ -320,7 +535,7 @@ class FBCompiler(sql.compiler.SQLCompiler):
for c in expression._select_iterables(returning_cols)
]
- return 'RETURNING ' + ', '.join(columns)
+ return "RETURNING " + ", ".join(columns)
class FBDDLCompiler(sql.compiler.DDLCompiler):
@@ -333,27 +548,33 @@ class FBDDLCompiler(sql.compiler.DDLCompiler):
# http://www.firebirdsql.org/manual/generatorguide-sqlsyntax.html
if create.element.start is not None:
raise NotImplemented(
- "Firebird SEQUENCE doesn't support START WITH")
+ "Firebird SEQUENCE doesn't support START WITH"
+ )
if create.element.increment is not None:
raise NotImplemented(
- "Firebird SEQUENCE doesn't support INCREMENT BY")
+ "Firebird SEQUENCE doesn't support INCREMENT BY"
+ )
if self.dialect._version_two:
- return "CREATE SEQUENCE %s" % \
- self.preparer.format_sequence(create.element)
+ return "CREATE SEQUENCE %s" % self.preparer.format_sequence(
+ create.element
+ )
else:
- return "CREATE GENERATOR %s" % \
- self.preparer.format_sequence(create.element)
+ return "CREATE GENERATOR %s" % self.preparer.format_sequence(
+ create.element
+ )
def visit_drop_sequence(self, drop):
"""Generate a ``DROP GENERATOR`` statement for the sequence."""
if self.dialect._version_two:
- return "DROP SEQUENCE %s" % \
- self.preparer.format_sequence(drop.element)
+ return "DROP SEQUENCE %s" % self.preparer.format_sequence(
+ drop.element
+ )
else:
- return "DROP GENERATOR %s" % \
- self.preparer.format_sequence(drop.element)
+ return "DROP GENERATOR %s" % self.preparer.format_sequence(
+ drop.element
+ )
class FBIdentifierPreparer(sql.compiler.IdentifierPreparer):
@@ -361,7 +582,8 @@ class FBIdentifierPreparer(sql.compiler.IdentifierPreparer):
reserved_words = RESERVED_WORDS
illegal_initial_characters = compiler.ILLEGAL_INITIAL_CHARACTERS.union(
- ['_'])
+ ["_"]
+ )
def __init__(self, dialect):
super(FBIdentifierPreparer, self).__init__(dialect, omit_schema=True)
@@ -372,16 +594,16 @@ class FBExecutionContext(default.DefaultExecutionContext):
"""Get the next value from the sequence using ``gen_id()``."""
return self._execute_scalar(
- "SELECT gen_id(%s, 1) FROM rdb$database" %
- self.dialect.identifier_preparer.format_sequence(seq),
- type_
+ "SELECT gen_id(%s, 1) FROM rdb$database"
+ % self.dialect.identifier_preparer.format_sequence(seq),
+ type_,
)
class FBDialect(default.DefaultDialect):
"""Firebird dialect"""
- name = 'firebird'
+ name = "firebird"
max_identifier_length = 31
@@ -413,23 +635,23 @@ class FBDialect(default.DefaultDialect):
def initialize(self, connection):
super(FBDialect, self).initialize(connection)
- self._version_two = ('firebird' in self.server_version_info and
- self.server_version_info >= (2, )
- ) or \
- ('interbase' in self.server_version_info and
- self.server_version_info >= (6, )
- )
+ self._version_two = (
+ "firebird" in self.server_version_info
+ and self.server_version_info >= (2,)
+ ) or (
+ "interbase" in self.server_version_info
+ and self.server_version_info >= (6,)
+ )
if not self._version_two:
# TODO: whatever other pre < 2.0 stuff goes here
self.ischema_names = ischema_names.copy()
- self.ischema_names['TIMESTAMP'] = sqltypes.DATE
- self.colspecs = {
- sqltypes.DateTime: sqltypes.DATE
- }
+ self.ischema_names["TIMESTAMP"] = sqltypes.DATE
+ self.colspecs = {sqltypes.DateTime: sqltypes.DATE}
- self.implicit_returning = self._version_two and \
- self.__dict__.get('implicit_returning', True)
+ self.implicit_returning = self._version_two and self.__dict__.get(
+ "implicit_returning", True
+ )
def normalize_name(self, name):
# Remove trailing spaces: FB uses a CHAR() type,
@@ -437,8 +659,9 @@ class FBDialect(default.DefaultDialect):
name = name and name.rstrip()
if name is None:
return None
- elif name.upper() == name and \
- not self.identifier_preparer._requires_quotes(name.lower()):
+ elif name.upper() == name and not self.identifier_preparer._requires_quotes(
+ name.lower()
+ ):
return name.lower()
elif name.lower() == name:
return quoted_name(name, quote=True)
@@ -448,8 +671,9 @@ class FBDialect(default.DefaultDialect):
def denormalize_name(self, name):
if name is None:
return None
- elif name.lower() == name and \
- not self.identifier_preparer._requires_quotes(name.lower()):
+ elif name.lower() == name and not self.identifier_preparer._requires_quotes(
+ name.lower()
+ ):
return name.upper()
else:
return name
@@ -522,7 +746,7 @@ class FBDialect(default.DefaultDialect):
rp = connection.execute(qry, [self.denormalize_name(view_name)])
row = rp.first()
if row:
- return row['view_source']
+ return row["view_source"]
else:
return None
@@ -538,13 +762,13 @@ class FBDialect(default.DefaultDialect):
tablename = self.denormalize_name(table_name)
# get primary key fields
c = connection.execute(keyqry, ["PRIMARY KEY", tablename])
- pkfields = [self.normalize_name(r['fname']) for r in c.fetchall()]
- return {'constrained_columns': pkfields, 'name': None}
+ pkfields = [self.normalize_name(r["fname"]) for r in c.fetchall()]
+ return {"constrained_columns": pkfields, "name": None}
@reflection.cache
- def get_column_sequence(self, connection,
- table_name, column_name,
- schema=None, **kw):
+ def get_column_sequence(
+ self, connection, table_name, column_name, schema=None, **kw
+ ):
tablename = self.denormalize_name(table_name)
colname = self.denormalize_name(column_name)
# Heuristic-query to determine the generator associated to a PK field
@@ -567,7 +791,7 @@ class FBDialect(default.DefaultDialect):
"""
genr = connection.execute(genqry, [tablename, colname]).first()
if genr is not None:
- return dict(name=self.normalize_name(genr['fgenerator']))
+ return dict(name=self.normalize_name(genr["fgenerator"]))
@reflection.cache
def get_columns(self, connection, table_name, schema=None, **kw):
@@ -595,7 +819,7 @@ class FBDialect(default.DefaultDialect):
"""
# get the PK, used to determine the eventual associated sequence
pk_constraint = self.get_pk_constraint(connection, table_name)
- pkey_cols = pk_constraint['constrained_columns']
+ pkey_cols = pk_constraint["constrained_columns"]
tablename = self.denormalize_name(table_name)
# get all of the fields for this table
@@ -605,26 +829,28 @@ class FBDialect(default.DefaultDialect):
row = c.fetchone()
if row is None:
break
- name = self.normalize_name(row['fname'])
- orig_colname = row['fname']
+ name = self.normalize_name(row["fname"])
+ orig_colname = row["fname"]
# get the data type
- colspec = row['ftype'].rstrip()
+ colspec = row["ftype"].rstrip()
coltype = self.ischema_names.get(colspec)
if coltype is None:
- util.warn("Did not recognize type '%s' of column '%s'" %
- (colspec, name))
+ util.warn(
+ "Did not recognize type '%s' of column '%s'"
+ % (colspec, name)
+ )
coltype = sqltypes.NULLTYPE
- elif issubclass(coltype, Integer) and row['fprec'] != 0:
+ elif issubclass(coltype, Integer) and row["fprec"] != 0:
coltype = NUMERIC(
- precision=row['fprec'],
- scale=row['fscale'] * -1)
- elif colspec in ('VARYING', 'CSTRING'):
- coltype = coltype(row['flen'])
- elif colspec == 'TEXT':
- coltype = TEXT(row['flen'])
- elif colspec == 'BLOB':
- if row['stype'] == 1:
+ precision=row["fprec"], scale=row["fscale"] * -1
+ )
+ elif colspec in ("VARYING", "CSTRING"):
+ coltype = coltype(row["flen"])
+ elif colspec == "TEXT":
+ coltype = TEXT(row["flen"])
+ elif colspec == "BLOB":
+ if row["stype"] == 1:
coltype = TEXT()
else:
coltype = BLOB()
@@ -633,36 +859,36 @@ class FBDialect(default.DefaultDialect):
# does it have a default value?
defvalue = None
- if row['fdefault'] is not None:
+ if row["fdefault"] is not None:
# the value comes down as "DEFAULT 'value'": there may be
# more than one whitespace around the "DEFAULT" keyword
# and it may also be lower case
# (see also http://tracker.firebirdsql.org/browse/CORE-356)
- defexpr = row['fdefault'].lstrip()
- assert defexpr[:8].rstrip().upper() == \
- 'DEFAULT', "Unrecognized default value: %s" % \
- defexpr
+ defexpr = row["fdefault"].lstrip()
+ assert defexpr[:8].rstrip().upper() == "DEFAULT", (
+ "Unrecognized default value: %s" % defexpr
+ )
defvalue = defexpr[8:].strip()
- if defvalue == 'NULL':
+ if defvalue == "NULL":
# Redundant
defvalue = None
col_d = {
- 'name': name,
- 'type': coltype,
- 'nullable': not bool(row['null_flag']),
- 'default': defvalue,
- 'autoincrement': 'auto',
+ "name": name,
+ "type": coltype,
+ "nullable": not bool(row["null_flag"]),
+ "default": defvalue,
+ "autoincrement": "auto",
}
if orig_colname.lower() == orig_colname:
- col_d['quote'] = True
+ col_d["quote"] = True
# if the PK is a single field, try to see if its linked to
# a sequence thru a trigger
if len(pkey_cols) == 1 and name == pkey_cols[0]:
seq_d = self.get_column_sequence(connection, tablename, name)
if seq_d is not None:
- col_d['sequence'] = seq_d
+ col_d["sequence"] = seq_d
cols.append(col_d)
return cols
@@ -689,24 +915,26 @@ class FBDialect(default.DefaultDialect):
tablename = self.denormalize_name(table_name)
c = connection.execute(fkqry, ["FOREIGN KEY", tablename])
- fks = util.defaultdict(lambda: {
- 'name': None,
- 'constrained_columns': [],
- 'referred_schema': None,
- 'referred_table': None,
- 'referred_columns': []
- })
+ fks = util.defaultdict(
+ lambda: {
+ "name": None,
+ "constrained_columns": [],
+ "referred_schema": None,
+ "referred_table": None,
+ "referred_columns": [],
+ }
+ )
for row in c:
- cname = self.normalize_name(row['cname'])
+ cname = self.normalize_name(row["cname"])
fk = fks[cname]
- if not fk['name']:
- fk['name'] = cname
- fk['referred_table'] = self.normalize_name(row['targetrname'])
- fk['constrained_columns'].append(
- self.normalize_name(row['fname']))
- fk['referred_columns'].append(
- self.normalize_name(row['targetfname']))
+ if not fk["name"]:
+ fk["name"] = cname
+ fk["referred_table"] = self.normalize_name(row["targetrname"])
+ fk["constrained_columns"].append(self.normalize_name(row["fname"]))
+ fk["referred_columns"].append(
+ self.normalize_name(row["targetfname"])
+ )
return list(fks.values())
@reflection.cache
@@ -729,13 +957,14 @@ class FBDialect(default.DefaultDialect):
indexes = util.defaultdict(dict)
for row in c:
- indexrec = indexes[row['index_name']]
- if 'name' not in indexrec:
- indexrec['name'] = self.normalize_name(row['index_name'])
- indexrec['column_names'] = []
- indexrec['unique'] = bool(row['unique_flag'])
-
- indexrec['column_names'].append(
- self.normalize_name(row['field_name']))
+ indexrec = indexes[row["index_name"]]
+ if "name" not in indexrec:
+ indexrec["name"] = self.normalize_name(row["index_name"])
+ indexrec["column_names"] = []
+ indexrec["unique"] = bool(row["unique_flag"])
+
+ indexrec["column_names"].append(
+ self.normalize_name(row["field_name"])
+ )
return list(indexes.values())
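
normalize_name()/denormalize_name() in the hunks above implement Firebird's case-folding convention: identifiers the server stores upper case come back lower case, and lower-case names are sent to the server upper case. A rough standalone sketch with the identifier-quoting checks elided:

    def normalize_name(name):
        # Firebird uses a CHAR() type for names, so strip trailing spaces
        name = name and name.rstrip()
        if name is None:
            return None
        if name.upper() == name:
            return name.lower()
        return name

    def denormalize_name(name):
        if name is None:
            return None
        if name.lower() == name:
            return name.upper()
        return name

    assert normalize_name("EMPLOYEES  ") == "employees"
    assert denormalize_name("employees") == "EMPLOYEES"
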
diff --git a/lib/sqlalchemy/dialects/firebird/fdb.py b/lib/sqlalchemy/dialects/firebird/fdb.py
index e8da6e1b7..5bf3d2c49 100644
--- a/lib/sqlalchemy/dialects/firebird/fdb.py
+++ b/lib/sqlalchemy/dialects/firebird/fdb.py
@@ -73,25 +73,23 @@ from ... import util
class FBDialect_fdb(FBDialect_kinterbasdb):
-
- def __init__(self, enable_rowcount=True,
- retaining=False, **kwargs):
+ def __init__(self, enable_rowcount=True, retaining=False, **kwargs):
super(FBDialect_fdb, self).__init__(
- enable_rowcount=enable_rowcount,
- retaining=retaining, **kwargs)
+ enable_rowcount=enable_rowcount, retaining=retaining, **kwargs
+ )
@classmethod
def dbapi(cls):
- return __import__('fdb')
+ return __import__("fdb")
def create_connect_args(self, url):
- opts = url.translate_connect_args(username='user')
- if opts.get('port'):
- opts['host'] = "%s/%s" % (opts['host'], opts['port'])
- del opts['port']
+ opts = url.translate_connect_args(username="user")
+ if opts.get("port"):
+ opts["host"] = "%s/%s" % (opts["host"], opts["port"])
+ del opts["port"]
opts.update(url.query)
- util.coerce_kw_type(opts, 'type_conv', int)
+ util.coerce_kw_type(opts, "type_conv", int)
return ([], opts)
@@ -115,4 +113,5 @@ class FBDialect_fdb(FBDialect_kinterbasdb):
return self._parse_version_info(version)
+
dialect = FBDialect_fdb
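
create_connect_args() above folds an explicit port into the host string, since fdb expects "host/port" rather than a separate port argument. A sketch with a hypothetical opts dict:

    opts = {"host": "localhost", "port": 3050, "user": "sysdba"}  # made-up values
    if opts.get("port"):
        opts["host"] = "%s/%s" % (opts["host"], opts["port"])
        del opts["port"]
    print(opts)  # {'host': 'localhost/3050', 'user': 'sysdba'}
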
diff --git a/lib/sqlalchemy/dialects/firebird/kinterbasdb.py b/lib/sqlalchemy/dialects/firebird/kinterbasdb.py
index dc88fc849..6d7144096 100644
--- a/lib/sqlalchemy/dialects/firebird/kinterbasdb.py
+++ b/lib/sqlalchemy/dialects/firebird/kinterbasdb.py
@@ -51,6 +51,7 @@ class _kinterbasdb_numeric(object):
return str(value)
else:
return value
+
return process
@@ -65,15 +66,16 @@ class _FBFloat_kinterbasdb(_kinterbasdb_numeric, sqltypes.Float):
class FBExecutionContext_kinterbasdb(FBExecutionContext):
@property
def rowcount(self):
- if self.execution_options.get('enable_rowcount',
- self.dialect.enable_rowcount):
+ if self.execution_options.get(
+ "enable_rowcount", self.dialect.enable_rowcount
+ ):
return self.cursor.rowcount
else:
return -1
class FBDialect_kinterbasdb(FBDialect):
- driver = 'kinterbasdb'
+ driver = "kinterbasdb"
supports_sane_rowcount = False
supports_sane_multi_rowcount = False
execution_ctx_cls = FBExecutionContext_kinterbasdb
@@ -85,13 +87,17 @@ class FBDialect_kinterbasdb(FBDialect):
{
sqltypes.Numeric: _FBNumeric_kinterbasdb,
sqltypes.Float: _FBFloat_kinterbasdb,
- }
-
+ },
)
- def __init__(self, type_conv=200, concurrency_level=1,
- enable_rowcount=True,
- retaining=False, **kwargs):
+ def __init__(
+ self,
+ type_conv=200,
+ concurrency_level=1,
+ enable_rowcount=True,
+ retaining=False,
+ **kwargs
+ ):
super(FBDialect_kinterbasdb, self).__init__(**kwargs)
self.enable_rowcount = enable_rowcount
self.type_conv = type_conv
@@ -102,7 +108,7 @@ class FBDialect_kinterbasdb(FBDialect):
@classmethod
def dbapi(cls):
- return __import__('kinterbasdb')
+ return __import__("kinterbasdb")
def do_execute(self, cursor, statement, parameters, context=None):
# kinterbase does not accept a None, but wants an empty list
@@ -116,28 +122,30 @@ class FBDialect_kinterbasdb(FBDialect):
dbapi_connection.commit(self.retaining)
def create_connect_args(self, url):
- opts = url.translate_connect_args(username='user')
- if opts.get('port'):
- opts['host'] = "%s/%s" % (opts['host'], opts['port'])
- del opts['port']
+ opts = url.translate_connect_args(username="user")
+ if opts.get("port"):
+ opts["host"] = "%s/%s" % (opts["host"], opts["port"])
+ del opts["port"]
opts.update(url.query)
- util.coerce_kw_type(opts, 'type_conv', int)
+ util.coerce_kw_type(opts, "type_conv", int)
- type_conv = opts.pop('type_conv', self.type_conv)
- concurrency_level = opts.pop('concurrency_level',
- self.concurrency_level)
+ type_conv = opts.pop("type_conv", self.type_conv)
+ concurrency_level = opts.pop(
+ "concurrency_level", self.concurrency_level
+ )
if self.dbapi is not None:
- initialized = getattr(self.dbapi, 'initialized', None)
+ initialized = getattr(self.dbapi, "initialized", None)
if initialized is None:
# CVS rev 1.96 changed the name of the attribute:
# http://kinterbasdb.cvs.sourceforge.net/viewvc/kinterbasdb/
# Kinterbasdb-3.0/__init__.py?r1=1.95&r2=1.96
- initialized = getattr(self.dbapi, '_initialized', False)
+ initialized = getattr(self.dbapi, "_initialized", False)
if not initialized:
- self.dbapi.init(type_conv=type_conv,
- concurrency_level=concurrency_level)
+ self.dbapi.init(
+ type_conv=type_conv, concurrency_level=concurrency_level
+ )
return ([], opts)
def _get_server_version_info(self, connection):
@@ -160,25 +168,31 @@ class FBDialect_kinterbasdb(FBDialect):
def _parse_version_info(self, version):
m = match(
- r'\w+-V(\d+)\.(\d+)\.(\d+)\.(\d+)( \w+ (\d+)\.(\d+))?', version)
+ r"\w+-V(\d+)\.(\d+)\.(\d+)\.(\d+)( \w+ (\d+)\.(\d+))?", version
+ )
if not m:
raise AssertionError(
- "Could not determine version from string '%s'" % version)
+ "Could not determine version from string '%s'" % version
+ )
if m.group(5) != None:
- return tuple([int(x) for x in m.group(6, 7, 4)] + ['firebird'])
+ return tuple([int(x) for x in m.group(6, 7, 4)] + ["firebird"])
else:
- return tuple([int(x) for x in m.group(1, 2, 3)] + ['interbase'])
+ return tuple([int(x) for x in m.group(1, 2, 3)] + ["interbase"])
def is_disconnect(self, e, connection, cursor):
- if isinstance(e, (self.dbapi.OperationalError,
- self.dbapi.ProgrammingError)):
+ if isinstance(
+ e, (self.dbapi.OperationalError, self.dbapi.ProgrammingError)
+ ):
msg = str(e)
- return ('Unable to complete network request to host' in msg or
- 'Invalid connection state' in msg or
- 'Invalid cursor state' in msg or
- 'connection shutdown' in msg)
+ return (
+ "Unable to complete network request to host" in msg
+ or "Invalid connection state" in msg
+ or "Invalid cursor state" in msg
+ or "connection shutdown" in msg
+ )
else:
return False
+
dialect = FBDialect_kinterbasdb
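
_parse_version_info() above distinguishes Firebird from legacy InterBase by whether the trailing " <word> N.N" group matched. A standalone check of the reflowed regex against the kind of string these servers report; the sample string is illustrative:

    from re import match

    def parse_version_info(version):
        m = match(r"\w+-V(\d+)\.(\d+)\.(\d+)\.(\d+)( \w+ (\d+)\.(\d+))?", version)
        if not m:
            raise AssertionError(
                "Could not determine version from string '%s'" % version
            )
        if m.group(5) is not None:
            return tuple([int(x) for x in m.group(6, 7, 4)] + ["firebird"])
        return tuple([int(x) for x in m.group(1, 2, 3)] + ["interbase"])

    print(parse_version_info("WI-V6.3.4.11608 Firebird 2.5"))
    # (2, 5, 11608, 'firebird')
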
diff --git a/lib/sqlalchemy/dialects/mssql/__init__.py b/lib/sqlalchemy/dialects/mssql/__init__.py
index 9c861e89d..88a94fcfb 100644
--- a/lib/sqlalchemy/dialects/mssql/__init__.py
+++ b/lib/sqlalchemy/dialects/mssql/__init__.py
@@ -7,20 +7,74 @@
from . import base, pyodbc, adodbapi, pymssql, zxjdbc, mxodbc # noqa
-from .base import \
- INTEGER, BIGINT, SMALLINT, TINYINT, VARCHAR, NVARCHAR, CHAR, \
- NCHAR, TEXT, NTEXT, DECIMAL, NUMERIC, FLOAT, DATETIME,\
- DATETIME2, DATETIMEOFFSET, DATE, TIME, SMALLDATETIME, \
- BINARY, VARBINARY, BIT, REAL, IMAGE, TIMESTAMP, ROWVERSION, \
- MONEY, SMALLMONEY, UNIQUEIDENTIFIER, SQL_VARIANT, XML
+from .base import (
+ INTEGER,
+ BIGINT,
+ SMALLINT,
+ TINYINT,
+ VARCHAR,
+ NVARCHAR,
+ CHAR,
+ NCHAR,
+ TEXT,
+ NTEXT,
+ DECIMAL,
+ NUMERIC,
+ FLOAT,
+ DATETIME,
+ DATETIME2,
+ DATETIMEOFFSET,
+ DATE,
+ TIME,
+ SMALLDATETIME,
+ BINARY,
+ VARBINARY,
+ BIT,
+ REAL,
+ IMAGE,
+ TIMESTAMP,
+ ROWVERSION,
+ MONEY,
+ SMALLMONEY,
+ UNIQUEIDENTIFIER,
+ SQL_VARIANT,
+ XML,
+)
base.dialect = dialect = pyodbc.dialect
__all__ = (
- 'INTEGER', 'BIGINT', 'SMALLINT', 'TINYINT', 'VARCHAR', 'NVARCHAR', 'CHAR',
- 'NCHAR', 'TEXT', 'NTEXT', 'DECIMAL', 'NUMERIC', 'FLOAT', 'DATETIME',
- 'DATETIME2', 'DATETIMEOFFSET', 'DATE', 'TIME', 'SMALLDATETIME',
- 'BINARY', 'VARBINARY', 'BIT', 'REAL', 'IMAGE', 'TIMESTAMP', 'ROWVERSION',
- 'MONEY', 'SMALLMONEY', 'UNIQUEIDENTIFIER', 'SQL_VARIANT', 'XML', 'dialect'
+ "INTEGER",
+ "BIGINT",
+ "SMALLINT",
+ "TINYINT",
+ "VARCHAR",
+ "NVARCHAR",
+ "CHAR",
+ "NCHAR",
+ "TEXT",
+ "NTEXT",
+ "DECIMAL",
+ "NUMERIC",
+ "FLOAT",
+ "DATETIME",
+ "DATETIME2",
+ "DATETIMEOFFSET",
+ "DATE",
+ "TIME",
+ "SMALLDATETIME",
+ "BINARY",
+ "VARBINARY",
+ "BIT",
+ "REAL",
+ "IMAGE",
+ "TIMESTAMP",
+ "ROWVERSION",
+ "MONEY",
+ "SMALLMONEY",
+ "UNIQUEIDENTIFIER",
+ "SQL_VARIANT",
+ "XML",
+ "dialect",
)
diff --git a/lib/sqlalchemy/dialects/mssql/adodbapi.py b/lib/sqlalchemy/dialects/mssql/adodbapi.py
index e5bb9ba57..d985c3bb6 100644
--- a/lib/sqlalchemy/dialects/mssql/adodbapi.py
+++ b/lib/sqlalchemy/dialects/mssql/adodbapi.py
@@ -33,6 +33,7 @@ class MSDateTime_adodbapi(MSDateTime):
if type(value) is datetime.date:
return datetime.datetime(value.year, value.month, value.day)
return value
+
return process
@@ -41,18 +42,16 @@ class MSDialect_adodbapi(MSDialect):
supports_sane_multi_rowcount = True
supports_unicode = sys.maxunicode == 65535
supports_unicode_statements = True
- driver = 'adodbapi'
+ driver = "adodbapi"
@classmethod
def import_dbapi(cls):
import adodbapi as module
+
return module
colspecs = util.update_copy(
- MSDialect.colspecs,
- {
- sqltypes.DateTime: MSDateTime_adodbapi
- }
+ MSDialect.colspecs, {sqltypes.DateTime: MSDateTime_adodbapi}
)
def create_connect_args(self, url):
@@ -61,14 +60,13 @@ class MSDialect_adodbapi(MSDialect):
token = "'%s'" % token
return token
- keys = dict(
- (k, check_quote(v)) for k, v in url.query.items()
- )
+ keys = dict((k, check_quote(v)) for k, v in url.query.items())
connectors = ["Provider=SQLOLEDB"]
- if 'port' in keys:
- connectors.append("Data Source=%s, %s" %
- (keys.get("host"), keys.get("port")))
+ if "port" in keys:
+ connectors.append(
+ "Data Source=%s, %s" % (keys.get("host"), keys.get("port"))
+ )
else:
connectors.append("Data Source=%s" % keys.get("host"))
connectors.append("Initial Catalog=%s" % keys.get("database"))
@@ -81,7 +79,9 @@ class MSDialect_adodbapi(MSDialect):
return [[";".join(connectors)], {}]
def is_disconnect(self, e, connection, cursor):
- return isinstance(e, self.dbapi.adodbapi.DatabaseError) and \
- "'connection failure'" in str(e)
+ return isinstance(
+ e, self.dbapi.adodbapi.DatabaseError
+ ) and "'connection failure'" in str(e)
+
dialect = MSDialect_adodbapi
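
check_quote() above guards the semicolon-delimited ADO connection string: any value containing ";" is single-quoted so it cannot split the string. A standalone sketch:

    def check_quote(token):
        # a bare ";" would otherwise read as a key/value separator downstream
        if ";" in str(token):
            token = "'%s'" % token
        return token

    assert check_quote("pass;word") == "'pass;word'"
    assert check_quote("plain") == "plain"
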
diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py
index 9269225d3..161297015 100644
--- a/lib/sqlalchemy/dialects/mssql/base.py
+++ b/lib/sqlalchemy/dialects/mssql/base.py
@@ -655,9 +655,22 @@ from ...sql import compiler, expression, util as sql_util, quoted_name
from ... import engine
from ...engine import reflection, default
from ... import types as sqltypes
-from ...types import INTEGER, BIGINT, SMALLINT, DECIMAL, NUMERIC, \
- FLOAT, DATETIME, DATE, BINARY, \
- TEXT, VARCHAR, NVARCHAR, CHAR, NCHAR
+from ...types import (
+ INTEGER,
+ BIGINT,
+ SMALLINT,
+ DECIMAL,
+ NUMERIC,
+ FLOAT,
+ DATETIME,
+ DATE,
+ BINARY,
+ TEXT,
+ VARCHAR,
+ NVARCHAR,
+ CHAR,
+ NCHAR,
+)
from ...util import update_wrapper
@@ -672,48 +685,202 @@ MS_2005_VERSION = (9,)
MS_2000_VERSION = (8,)
RESERVED_WORDS = set(
- ['add', 'all', 'alter', 'and', 'any', 'as', 'asc', 'authorization',
- 'backup', 'begin', 'between', 'break', 'browse', 'bulk', 'by', 'cascade',
- 'case', 'check', 'checkpoint', 'close', 'clustered', 'coalesce',
- 'collate', 'column', 'commit', 'compute', 'constraint', 'contains',
- 'containstable', 'continue', 'convert', 'create', 'cross', 'current',
- 'current_date', 'current_time', 'current_timestamp', 'current_user',
- 'cursor', 'database', 'dbcc', 'deallocate', 'declare', 'default',
- 'delete', 'deny', 'desc', 'disk', 'distinct', 'distributed', 'double',
- 'drop', 'dump', 'else', 'end', 'errlvl', 'escape', 'except', 'exec',
- 'execute', 'exists', 'exit', 'external', 'fetch', 'file', 'fillfactor',
- 'for', 'foreign', 'freetext', 'freetexttable', 'from', 'full',
- 'function', 'goto', 'grant', 'group', 'having', 'holdlock', 'identity',
- 'identity_insert', 'identitycol', 'if', 'in', 'index', 'inner', 'insert',
- 'intersect', 'into', 'is', 'join', 'key', 'kill', 'left', 'like',
- 'lineno', 'load', 'merge', 'national', 'nocheck', 'nonclustered', 'not',
- 'null', 'nullif', 'of', 'off', 'offsets', 'on', 'open', 'opendatasource',
- 'openquery', 'openrowset', 'openxml', 'option', 'or', 'order', 'outer',
- 'over', 'percent', 'pivot', 'plan', 'precision', 'primary', 'print',
- 'proc', 'procedure', 'public', 'raiserror', 'read', 'readtext',
- 'reconfigure', 'references', 'replication', 'restore', 'restrict',
- 'return', 'revert', 'revoke', 'right', 'rollback', 'rowcount',
- 'rowguidcol', 'rule', 'save', 'schema', 'securityaudit', 'select',
- 'session_user', 'set', 'setuser', 'shutdown', 'some', 'statistics',
- 'system_user', 'table', 'tablesample', 'textsize', 'then', 'to', 'top',
- 'tran', 'transaction', 'trigger', 'truncate', 'tsequal', 'union',
- 'unique', 'unpivot', 'update', 'updatetext', 'use', 'user', 'values',
- 'varying', 'view', 'waitfor', 'when', 'where', 'while', 'with',
- 'writetext',
- ])
+ [
+ "add",
+ "all",
+ "alter",
+ "and",
+ "any",
+ "as",
+ "asc",
+ "authorization",
+ "backup",
+ "begin",
+ "between",
+ "break",
+ "browse",
+ "bulk",
+ "by",
+ "cascade",
+ "case",
+ "check",
+ "checkpoint",
+ "close",
+ "clustered",
+ "coalesce",
+ "collate",
+ "column",
+ "commit",
+ "compute",
+ "constraint",
+ "contains",
+ "containstable",
+ "continue",
+ "convert",
+ "create",
+ "cross",
+ "current",
+ "current_date",
+ "current_time",
+ "current_timestamp",
+ "current_user",
+ "cursor",
+ "database",
+ "dbcc",
+ "deallocate",
+ "declare",
+ "default",
+ "delete",
+ "deny",
+ "desc",
+ "disk",
+ "distinct",
+ "distributed",
+ "double",
+ "drop",
+ "dump",
+ "else",
+ "end",
+ "errlvl",
+ "escape",
+ "except",
+ "exec",
+ "execute",
+ "exists",
+ "exit",
+ "external",
+ "fetch",
+ "file",
+ "fillfactor",
+ "for",
+ "foreign",
+ "freetext",
+ "freetexttable",
+ "from",
+ "full",
+ "function",
+ "goto",
+ "grant",
+ "group",
+ "having",
+ "holdlock",
+ "identity",
+ "identity_insert",
+ "identitycol",
+ "if",
+ "in",
+ "index",
+ "inner",
+ "insert",
+ "intersect",
+ "into",
+ "is",
+ "join",
+ "key",
+ "kill",
+ "left",
+ "like",
+ "lineno",
+ "load",
+ "merge",
+ "national",
+ "nocheck",
+ "nonclustered",
+ "not",
+ "null",
+ "nullif",
+ "of",
+ "off",
+ "offsets",
+ "on",
+ "open",
+ "opendatasource",
+ "openquery",
+ "openrowset",
+ "openxml",
+ "option",
+ "or",
+ "order",
+ "outer",
+ "over",
+ "percent",
+ "pivot",
+ "plan",
+ "precision",
+ "primary",
+ "print",
+ "proc",
+ "procedure",
+ "public",
+ "raiserror",
+ "read",
+ "readtext",
+ "reconfigure",
+ "references",
+ "replication",
+ "restore",
+ "restrict",
+ "return",
+ "revert",
+ "revoke",
+ "right",
+ "rollback",
+ "rowcount",
+ "rowguidcol",
+ "rule",
+ "save",
+ "schema",
+ "securityaudit",
+ "select",
+ "session_user",
+ "set",
+ "setuser",
+ "shutdown",
+ "some",
+ "statistics",
+ "system_user",
+ "table",
+ "tablesample",
+ "textsize",
+ "then",
+ "to",
+ "top",
+ "tran",
+ "transaction",
+ "trigger",
+ "truncate",
+ "tsequal",
+ "union",
+ "unique",
+ "unpivot",
+ "update",
+ "updatetext",
+ "use",
+ "user",
+ "values",
+ "varying",
+ "view",
+ "waitfor",
+ "when",
+ "where",
+ "while",
+ "with",
+ "writetext",
+ ]
+)
class REAL(sqltypes.REAL):
- __visit_name__ = 'REAL'
+ __visit_name__ = "REAL"
def __init__(self, **kw):
# REAL is a synonym for FLOAT(24) on SQL server
- kw['precision'] = 24
+ kw["precision"] = 24
super(REAL, self).__init__(**kw)
class TINYINT(sqltypes.Integer):
- __visit_name__ = 'TINYINT'
+ __visit_name__ = "TINYINT"
# MSSQL DATE/TIME types have varied behavior, sometimes returning
@@ -721,14 +888,15 @@ class TINYINT(sqltypes.Integer):
# filter bind parameters into datetime objects (required by pyodbc,
# not sure about other dialects).
-class _MSDate(sqltypes.Date):
+class _MSDate(sqltypes.Date):
def bind_processor(self, dialect):
def process(value):
if type(value) == datetime.date:
return datetime.datetime(value.year, value.month, value.day)
else:
return value
+
return process
_reg = re.compile(r"(\d+)-(\d+)-(\d+)")
@@ -741,18 +909,16 @@ class _MSDate(sqltypes.Date):
m = self._reg.match(value)
if not m:
raise ValueError(
- "could not parse %r as a date value" % (value, ))
- return datetime.date(*[
- int(x or 0)
- for x in m.groups()
- ])
+ "could not parse %r as a date value" % (value,)
+ )
+ return datetime.date(*[int(x or 0) for x in m.groups()])
else:
return value
+
return process
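A quick illustration of the processor above, calling the private _MSDate type directly (for demonstration only; in practice the driver layer invokes it):

    import datetime
    from sqlalchemy.dialects import mssql
    from sqlalchemy.dialects.mssql.base import _MSDate

    proc = _MSDate().result_processor(mssql.dialect(), None)
    print(proc("2004-05-21"))                # datetime.date(2004, 5, 21)
    print(proc(datetime.date(2004, 5, 21)))  # non-strings pass through unchanged
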
class TIME(sqltypes.TIME):
-
def __init__(self, precision=None, **kwargs):
self.precision = precision
super(TIME, self).__init__()
@@ -763,10 +929,12 @@ class TIME(sqltypes.TIME):
def process(value):
if isinstance(value, datetime.datetime):
value = datetime.datetime.combine(
- self.__zero_date, value.time())
+ self.__zero_date, value.time()
+ )
elif isinstance(value, datetime.time):
value = datetime.datetime.combine(self.__zero_date, value)
return value
+
return process
_reg = re.compile(r"(\d+):(\d+):(\d+)(?:\.(\d{0,6}))?")
@@ -779,24 +947,26 @@ class TIME(sqltypes.TIME):
m = self._reg.match(value)
if not m:
raise ValueError(
- "could not parse %r as a time value" % (value, ))
- return datetime.time(*[
- int(x or 0)
- for x in m.groups()])
+ "could not parse %r as a time value" % (value,)
+ )
+ return datetime.time(*[int(x or 0) for x in m.groups()])
else:
return value
+
return process
+
+
_MSTime = TIME
class _DateTimeBase(object):
-
def bind_processor(self, dialect):
def process(value):
if type(value) == datetime.date:
return datetime.datetime(value.year, value.month, value.day)
else:
return value
+
return process
@@ -805,11 +975,11 @@ class _MSDateTime(_DateTimeBase, sqltypes.DateTime):
class SMALLDATETIME(_DateTimeBase, sqltypes.DateTime):
- __visit_name__ = 'SMALLDATETIME'
+ __visit_name__ = "SMALLDATETIME"
class DATETIME2(_DateTimeBase, sqltypes.DateTime):
- __visit_name__ = 'DATETIME2'
+ __visit_name__ = "DATETIME2"
def __init__(self, precision=None, **kw):
super(DATETIME2, self).__init__(**kw)
@@ -818,7 +988,7 @@ class DATETIME2(_DateTimeBase, sqltypes.DateTime):
# TODO: is this not an Interval ?
class DATETIMEOFFSET(sqltypes.TypeEngine):
- __visit_name__ = 'DATETIMEOFFSET'
+ __visit_name__ = "DATETIMEOFFSET"
def __init__(self, precision=None, **kwargs):
self.precision = precision
@@ -847,7 +1017,7 @@ class TIMESTAMP(sqltypes._Binary):
"""
- __visit_name__ = 'TIMESTAMP'
+ __visit_name__ = "TIMESTAMP"
# expected by _Binary to be present
length = None
@@ -866,12 +1036,14 @@ class TIMESTAMP(sqltypes._Binary):
def result_processor(self, dialect, coltype):
super_ = super(TIMESTAMP, self).result_processor(dialect, coltype)
if self.convert_int:
+
def process(value):
value = super_(value)
if value is not None:
# https://stackoverflow.com/a/30403242/34549
- value = int(codecs.encode(value, 'hex'), 16)
+ value = int(codecs.encode(value, "hex"), 16)
return value
+
return process
else:
return super_
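The hex round-trip used above, in isolation; the byte string is a made-up 8-byte rowversion value:

    import codecs

    value = b"\x00\x00\x00\x00\x00\x00\x07\xd2"
    # encode to hex digits, then parse those digits as a base-16 integer
    print(int(codecs.encode(value, "hex"), 16))  # 2002
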
@@ -898,7 +1070,7 @@ class ROWVERSION(TIMESTAMP):
"""
- __visit_name__ = 'ROWVERSION'
+ __visit_name__ = "ROWVERSION"
class NTEXT(sqltypes.UnicodeText):
@@ -906,7 +1078,7 @@ class NTEXT(sqltypes.UnicodeText):
"""MSSQL NTEXT type, for variable-length unicode text up to 2^30
characters."""
- __visit_name__ = 'NTEXT'
+ __visit_name__ = "NTEXT"
class VARBINARY(sqltypes.VARBINARY, sqltypes.LargeBinary):
@@ -925,11 +1097,12 @@ class VARBINARY(sqltypes.VARBINARY, sqltypes.LargeBinary):
"""
- __visit_name__ = 'VARBINARY'
+
+ __visit_name__ = "VARBINARY"
class IMAGE(sqltypes.LargeBinary):
- __visit_name__ = 'IMAGE'
+ __visit_name__ = "IMAGE"
class XML(sqltypes.Text):
@@ -943,19 +1116,20 @@ class XML(sqltypes.Text):
.. versionadded:: 1.1.11
"""
- __visit_name__ = 'XML'
+
+ __visit_name__ = "XML"
class BIT(sqltypes.TypeEngine):
- __visit_name__ = 'BIT'
+ __visit_name__ = "BIT"
class MONEY(sqltypes.TypeEngine):
- __visit_name__ = 'MONEY'
+ __visit_name__ = "MONEY"
class SMALLMONEY(sqltypes.TypeEngine):
- __visit_name__ = 'SMALLMONEY'
+ __visit_name__ = "SMALLMONEY"
class UNIQUEIDENTIFIER(sqltypes.TypeEngine):
@@ -963,7 +1137,8 @@ class UNIQUEIDENTIFIER(sqltypes.TypeEngine):
class SQL_VARIANT(sqltypes.TypeEngine):
- __visit_name__ = 'SQL_VARIANT'
+ __visit_name__ = "SQL_VARIANT"
+
# old names.
MSDateTime = _MSDateTime
@@ -990,36 +1165,36 @@ MSUniqueIdentifier = UNIQUEIDENTIFIER
MSVariant = SQL_VARIANT
ischema_names = {
- 'int': INTEGER,
- 'bigint': BIGINT,
- 'smallint': SMALLINT,
- 'tinyint': TINYINT,
- 'varchar': VARCHAR,
- 'nvarchar': NVARCHAR,
- 'char': CHAR,
- 'nchar': NCHAR,
- 'text': TEXT,
- 'ntext': NTEXT,
- 'decimal': DECIMAL,
- 'numeric': NUMERIC,
- 'float': FLOAT,
- 'datetime': DATETIME,
- 'datetime2': DATETIME2,
- 'datetimeoffset': DATETIMEOFFSET,
- 'date': DATE,
- 'time': TIME,
- 'smalldatetime': SMALLDATETIME,
- 'binary': BINARY,
- 'varbinary': VARBINARY,
- 'bit': BIT,
- 'real': REAL,
- 'image': IMAGE,
- 'xml': XML,
- 'timestamp': TIMESTAMP,
- 'money': MONEY,
- 'smallmoney': SMALLMONEY,
- 'uniqueidentifier': UNIQUEIDENTIFIER,
- 'sql_variant': SQL_VARIANT,
+ "int": INTEGER,
+ "bigint": BIGINT,
+ "smallint": SMALLINT,
+ "tinyint": TINYINT,
+ "varchar": VARCHAR,
+ "nvarchar": NVARCHAR,
+ "char": CHAR,
+ "nchar": NCHAR,
+ "text": TEXT,
+ "ntext": NTEXT,
+ "decimal": DECIMAL,
+ "numeric": NUMERIC,
+ "float": FLOAT,
+ "datetime": DATETIME,
+ "datetime2": DATETIME2,
+ "datetimeoffset": DATETIMEOFFSET,
+ "date": DATE,
+ "time": TIME,
+ "smalldatetime": SMALLDATETIME,
+ "binary": BINARY,
+ "varbinary": VARBINARY,
+ "bit": BIT,
+ "real": REAL,
+ "image": IMAGE,
+ "xml": XML,
+ "timestamp": TIMESTAMP,
+ "money": MONEY,
+ "smallmoney": SMALLMONEY,
+ "uniqueidentifier": UNIQUEIDENTIFIER,
+ "sql_variant": SQL_VARIANT,
}
@@ -1030,8 +1205,8 @@ class MSTypeCompiler(compiler.GenericTypeCompiler):
"""
- if getattr(type_, 'collation', None):
- collation = 'COLLATE %s' % type_.collation
+ if getattr(type_, "collation", None):
+ collation = "COLLATE %s" % type_.collation
else:
collation = None
@@ -1041,15 +1216,14 @@ class MSTypeCompiler(compiler.GenericTypeCompiler):
if length:
spec = spec + "(%s)" % length
- return ' '.join([c for c in (spec, collation)
- if c is not None])
+ return " ".join([c for c in (spec, collation) if c is not None])
def visit_FLOAT(self, type_, **kw):
- precision = getattr(type_, 'precision', None)
+ precision = getattr(type_, "precision", None)
if precision is None:
return "FLOAT"
else:
- return "FLOAT(%(precision)s)" % {'precision': precision}
+ return "FLOAT(%(precision)s)" % {"precision": precision}
def visit_TINYINT(self, type_, **kw):
return "TINYINT"
@@ -1061,7 +1235,7 @@ class MSTypeCompiler(compiler.GenericTypeCompiler):
return "DATETIMEOFFSET"
def visit_TIME(self, type_, **kw):
- precision = getattr(type_, 'precision', None)
+ precision = getattr(type_, "precision", None)
if precision is not None:
return "TIME(%s)" % precision
else:
@@ -1074,7 +1248,7 @@ class MSTypeCompiler(compiler.GenericTypeCompiler):
return "ROWVERSION"
def visit_DATETIME2(self, type_, **kw):
- precision = getattr(type_, 'precision', None)
+ precision = getattr(type_, "precision", None)
if precision is not None:
return "DATETIME2(%s)" % precision
else:
@@ -1105,7 +1279,7 @@ class MSTypeCompiler(compiler.GenericTypeCompiler):
return self._extend("TEXT", type_)
def visit_VARCHAR(self, type_, **kw):
- return self._extend("VARCHAR", type_, length=type_.length or 'max')
+ return self._extend("VARCHAR", type_, length=type_.length or "max")
def visit_CHAR(self, type_, **kw):
return self._extend("CHAR", type_)
@@ -1114,7 +1288,7 @@ class MSTypeCompiler(compiler.GenericTypeCompiler):
return self._extend("NCHAR", type_)
def visit_NVARCHAR(self, type_, **kw):
- return self._extend("NVARCHAR", type_, length=type_.length or 'max')
+ return self._extend("NVARCHAR", type_, length=type_.length or "max")
def visit_date(self, type_, **kw):
if self.dialect.server_version_info < MS_2008_VERSION:
@@ -1141,10 +1315,7 @@ class MSTypeCompiler(compiler.GenericTypeCompiler):
return "XML"
def visit_VARBINARY(self, type_, **kw):
- return self._extend(
- "VARBINARY",
- type_,
- length=type_.length or 'max')
+ return self._extend("VARBINARY", type_, length=type_.length or "max")
def visit_boolean(self, type_, **kw):
return self.visit_BIT(type_)
@@ -1156,13 +1327,13 @@ class MSTypeCompiler(compiler.GenericTypeCompiler):
return "MONEY"
def visit_SMALLMONEY(self, type_, **kw):
- return 'SMALLMONEY'
+ return "SMALLMONEY"
def visit_UNIQUEIDENTIFIER(self, type_, **kw):
return "UNIQUEIDENTIFIER"
def visit_SQL_VARIANT(self, type_, **kw):
- return 'SQL_VARIANT'
+ return "SQL_VARIANT"
class MSExecutionContext(default.DefaultExecutionContext):
@@ -1186,41 +1357,44 @@ class MSExecutionContext(default.DefaultExecutionContext):
insert_has_sequence = seq_column is not None
if insert_has_sequence:
- self._enable_identity_insert = \
- seq_column.key in self.compiled_parameters[0] or \
- (
- self.compiled.statement.parameters and (
- (
- self.compiled.statement._has_multi_parameters
- and
- seq_column.key in
- self.compiled.statement.parameters[0]
- ) or (
- not
- self.compiled.statement._has_multi_parameters
- and
- seq_column.key in
- self.compiled.statement.parameters
- )
+ self._enable_identity_insert = seq_column.key in self.compiled_parameters[
+ 0
+ ] or (
+ self.compiled.statement.parameters
+ and (
+ (
+ self.compiled.statement._has_multi_parameters
+ and seq_column.key
+ in self.compiled.statement.parameters[0]
+ )
+ or (
+ not self.compiled.statement._has_multi_parameters
+ and seq_column.key
+ in self.compiled.statement.parameters
)
)
+ )
else:
self._enable_identity_insert = False
- self._select_lastrowid = not self.compiled.inline and \
- insert_has_sequence and \
- not self.compiled.returning and \
- not self._enable_identity_insert and \
- not self.executemany
+ self._select_lastrowid = (
+ not self.compiled.inline
+ and insert_has_sequence
+ and not self.compiled.returning
+ and not self._enable_identity_insert
+ and not self.executemany
+ )
if self._enable_identity_insert:
self.root_connection._cursor_execute(
self.cursor,
self._opt_encode(
- "SET IDENTITY_INSERT %s ON" %
- self.dialect.identifier_preparer.format_table(tbl)),
+ "SET IDENTITY_INSERT %s ON"
+ % self.dialect.identifier_preparer.format_table(tbl)
+ ),
(),
- self)
+ self,
+ )
def post_exec(self):
"""Disable IDENTITY_INSERT if enabled."""
@@ -1230,29 +1404,35 @@ class MSExecutionContext(default.DefaultExecutionContext):
if self.dialect.use_scope_identity:
conn._cursor_execute(
self.cursor,
- "SELECT scope_identity() AS lastrowid", (), self)
+ "SELECT scope_identity() AS lastrowid",
+ (),
+ self,
+ )
else:
- conn._cursor_execute(self.cursor,
- "SELECT @@identity AS lastrowid",
- (),
- self)
+ conn._cursor_execute(
+ self.cursor, "SELECT @@identity AS lastrowid", (), self
+ )
# fetchall() ensures the cursor is consumed without closing it
row = self.cursor.fetchall()[0]
self._lastrowid = int(row[0])
- if (self.isinsert or self.isupdate or self.isdelete) and \
- self.compiled.returning:
+ if (
+ self.isinsert or self.isupdate or self.isdelete
+ ) and self.compiled.returning:
self._result_proxy = engine.FullyBufferedResultProxy(self)
if self._enable_identity_insert:
conn._cursor_execute(
self.cursor,
self._opt_encode(
- "SET IDENTITY_INSERT %s OFF" %
- self.dialect.identifier_preparer. format_table(
- self.compiled.statement.table)),
+ "SET IDENTITY_INSERT %s OFF"
+ % self.dialect.identifier_preparer.format_table(
+ self.compiled.statement.table
+ )
+ ),
(),
- self)
+ self,
+ )
def get_lastrowid(self):
return self._lastrowid
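A hedged usage sketch of the two execution paths above; the engine URL and table are hypothetical, and pyodbc plus a reachable SQL Server are assumed:

    from sqlalchemy import (
        Column, Integer, MetaData, String, Table, create_engine,
    )

    engine = create_engine("mssql+pyodbc://user:pw@mydsn")
    t = Table("t", MetaData(),
              Column("id", Integer, primary_key=True),
              Column("data", String(50)))
    with engine.begin() as conn:
        # explicit id: wrapped in SET IDENTITY_INSERT t ON ... OFF
        conn.execute(t.insert().values(id=7, data="x"))
        # omitted id: post_exec() runs SELECT scope_identity() AS lastrowid
        print(conn.execute(t.insert().values(data="y")).inserted_primary_key)
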
@@ -1262,9 +1442,12 @@ class MSExecutionContext(default.DefaultExecutionContext):
try:
self.cursor.execute(
self._opt_encode(
- "SET IDENTITY_INSERT %s OFF" %
- self.dialect.identifier_preparer. format_table(
- self.compiled.statement.table)))
+ "SET IDENTITY_INSERT %s OFF"
+ % self.dialect.identifier_preparer.format_table(
+ self.compiled.statement.table
+ )
+ )
+ )
except Exception:
pass
@@ -1281,11 +1464,12 @@ class MSSQLCompiler(compiler.SQLCompiler):
extract_map = util.update_copy(
compiler.SQLCompiler.extract_map,
{
- 'doy': 'dayofyear',
- 'dow': 'weekday',
- 'milliseconds': 'millisecond',
- 'microseconds': 'microsecond'
- })
+ "doy": "dayofyear",
+ "dow": "weekday",
+ "milliseconds": "millisecond",
+ "microseconds": "microsecond",
+ },
+ )
def __init__(self, *args, **kwargs):
self.tablealiases = {}
@@ -1298,6 +1482,7 @@ class MSSQLCompiler(compiler.SQLCompiler):
else:
super_ = getattr(super(MSSQLCompiler, self), fn.__name__)
return super_(*arg, **kw)
+
return decorate
def visit_now_func(self, fn, **kw):
@@ -1313,20 +1498,22 @@ class MSSQLCompiler(compiler.SQLCompiler):
return "LEN%s" % self.function_argspec(fn, **kw)
def visit_concat_op_binary(self, binary, operator, **kw):
- return "%s + %s" % \
- (self.process(binary.left, **kw),
- self.process(binary.right, **kw))
+ return "%s + %s" % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw),
+ )
def visit_true(self, expr, **kw):
- return '1'
+ return "1"
def visit_false(self, expr, **kw):
- return '0'
+ return "0"
def visit_match_op_binary(self, binary, operator, **kw):
return "CONTAINS (%s, %s)" % (
self.process(binary.left, **kw),
- self.process(binary.right, **kw))
+ self.process(binary.right, **kw),
+ )
def get_select_precolumns(self, select, **kw):
""" MS-SQL puts TOP, it's version of LIMIT here """
@@ -1345,7 +1532,8 @@ class MSSQLCompiler(compiler.SQLCompiler):
return s
else:
return compiler.SQLCompiler.get_select_precolumns(
- self, select, **kw)
+ self, select, **kw
+ )
def get_from_hint_text(self, table, text):
return text
@@ -1363,20 +1551,21 @@ class MSSQLCompiler(compiler.SQLCompiler):
"""
if (
- (
- not select._simple_int_limit and
- select._limit_clause is not None
- ) or (
- select._offset_clause is not None and
- not select._simple_int_offset or select._offset
+ (not select._simple_int_limit and select._limit_clause is not None)
+ or (
+ select._offset_clause is not None
+ and not select._simple_int_offset
+ or select._offset
)
- ) and not getattr(select, '_mssql_visit', None):
+ ) and not getattr(select, "_mssql_visit", None):
# to use ROW_NUMBER(), an ORDER BY is required.
if not select._order_by_clause.clauses:
- raise exc.CompileError('MSSQL requires an order_by when '
- 'using an OFFSET or a non-simple '
- 'LIMIT clause')
+ raise exc.CompileError(
+ "MSSQL requires an order_by when "
+ "using an OFFSET or a non-simple "
+ "LIMIT clause"
+ )
_order_by_clauses = [
sql_util.unwrap_label_reference(elem)
@@ -1385,24 +1574,31 @@ class MSSQLCompiler(compiler.SQLCompiler):
limit_clause = select._limit_clause
offset_clause = select._offset_clause
- kwargs['select_wraps_for'] = select
+ kwargs["select_wraps_for"] = select
select = select._generate()
select._mssql_visit = True
- select = select.column(
- sql.func.ROW_NUMBER().over(order_by=_order_by_clauses)
- .label("mssql_rn")).order_by(None).alias()
+ select = (
+ select.column(
+ sql.func.ROW_NUMBER()
+ .over(order_by=_order_by_clauses)
+ .label("mssql_rn")
+ )
+ .order_by(None)
+ .alias()
+ )
- mssql_rn = sql.column('mssql_rn')
- limitselect = sql.select([c for c in select.c if
- c.key != 'mssql_rn'])
+ mssql_rn = sql.column("mssql_rn")
+ limitselect = sql.select(
+ [c for c in select.c if c.key != "mssql_rn"]
+ )
if offset_clause is not None:
limitselect.append_whereclause(mssql_rn > offset_clause)
if limit_clause is not None:
limitselect.append_whereclause(
- mssql_rn <= (limit_clause + offset_clause))
+ mssql_rn <= (limit_clause + offset_clause)
+ )
else:
- limitselect.append_whereclause(
- mssql_rn <= (limit_clause))
+ limitselect.append_whereclause(mssql_rn <= (limit_clause))
return self.process(limitselect, **kwargs)
else:
return compiler.SQLCompiler.visit_select(self, select, **kwargs)
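Sketch of the ROW_NUMBER() wrapping taken when an OFFSET is present; the output is shown roughly, and an order_by() is mandatory or the CompileError above is raised:

    from sqlalchemy import Column, Integer, MetaData, Table, select
    from sqlalchemy.dialects import mssql

    t = Table("t", MetaData(), Column("x", Integer))
    q = select([t.c.x]).order_by(t.c.x).limit(10).offset(20)
    print(q.compile(dialect=mssql.dialect()))
    # SELECT anon_1.x FROM (SELECT t.x AS x,
    #   ROW_NUMBER() OVER (ORDER BY t.x) AS mssql_rn FROM t) AS anon_1
    # WHERE mssql_rn > :param_1 AND mssql_rn <= :param_2 + :param_1
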
@@ -1422,35 +1618,38 @@ class MSSQLCompiler(compiler.SQLCompiler):
@_with_legacy_schema_aliasing
def visit_alias(self, alias, **kw):
# translate for schema-qualified table aliases
- kw['mssql_aliased'] = alias.original
+ kw["mssql_aliased"] = alias.original
return super(MSSQLCompiler, self).visit_alias(alias, **kw)
@_with_legacy_schema_aliasing
def visit_column(self, column, add_to_result_map=None, **kw):
- if column.table is not None and \
- (not self.isupdate and not self.isdelete) or \
- self.is_subquery():
+ if (
+ column.table is not None
+ and (not self.isupdate and not self.isdelete)
+ or self.is_subquery()
+ ):
# translate for schema-qualified table aliases
t = self._schema_aliased_table(column.table)
if t is not None:
converted = expression._corresponding_column_or_error(
- t, column)
+ t, column
+ )
if add_to_result_map is not None:
add_to_result_map(
column.name,
column.name,
(column, column.name, column.key),
- column.type
+ column.type,
)
- return super(MSSQLCompiler, self).\
- visit_column(converted, **kw)
+ return super(MSSQLCompiler, self).visit_column(converted, **kw)
return super(MSSQLCompiler, self).visit_column(
- column, add_to_result_map=add_to_result_map, **kw)
+ column, add_to_result_map=add_to_result_map, **kw
+ )
def _schema_aliased_table(self, table):
- if getattr(table, 'schema', None) is not None:
+ if getattr(table, "schema", None) is not None:
if table not in self.tablealiases:
self.tablealiases[table] = table.alias()
return self.tablealiases[table]
@@ -1459,16 +1658,17 @@ class MSSQLCompiler(compiler.SQLCompiler):
def visit_extract(self, extract, **kw):
field = self.extract_map.get(extract.field, extract.field)
- return 'DATEPART(%s, %s)' % \
- (field, self.process(extract.expr, **kw))
+ return "DATEPART(%s, %s)" % (field, self.process(extract.expr, **kw))
def visit_savepoint(self, savepoint_stmt):
- return "SAVE TRANSACTION %s" % \
- self.preparer.format_savepoint(savepoint_stmt)
+ return "SAVE TRANSACTION %s" % self.preparer.format_savepoint(
+ savepoint_stmt
+ )
def visit_rollback_to_savepoint(self, savepoint_stmt):
- return ("ROLLBACK TRANSACTION %s"
- % self.preparer.format_savepoint(savepoint_stmt))
+ return "ROLLBACK TRANSACTION %s" % self.preparer.format_savepoint(
+ savepoint_stmt
+ )
def visit_binary(self, binary, **kwargs):
"""Move bind parameters to the right-hand side of an operator, where
@@ -1481,10 +1681,11 @@ class MSSQLCompiler(compiler.SQLCompiler):
and not isinstance(binary.right, expression.BindParameter)
):
return self.process(
- expression.BinaryExpression(binary.right,
- binary.left,
- binary.operator),
- **kwargs)
+ expression.BinaryExpression(
+ binary.right, binary.left, binary.operator
+ ),
+ **kwargs
+ )
return super(MSSQLCompiler, self).visit_binary(binary, **kwargs)
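Compile-only illustration of the operand flip:

    from sqlalchemy import Column, Integer, MetaData, Table, literal
    from sqlalchemy.dialects import mssql

    t = Table("t", MetaData(), Column("x", Integer))
    # a bind parameter on the left of == is moved to the right before rendering
    print((literal(5) == t.c.x).compile(dialect=mssql.dialect()))
    # t.x = :param_1
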
def returning_clause(self, stmt, returning_cols):
@@ -1497,12 +1698,13 @@ class MSSQLCompiler(compiler.SQLCompiler):
adapter = sql_util.ClauseAdapter(target)
columns = [
- self._label_select_column(None, adapter.traverse(c),
- True, False, {})
+ self._label_select_column(
+ None, adapter.traverse(c), True, False, {}
+ )
for c in expression._select_iterables(returning_cols)
]
- return 'OUTPUT ' + ', '.join(columns)
+ return "OUTPUT " + ", ".join(columns)
def get_cte_preamble(self, recursive):
# SQL Server finds it too inconvenient to accept
@@ -1515,13 +1717,14 @@ class MSSQLCompiler(compiler.SQLCompiler):
if isinstance(column, expression.Function):
return column.label(None)
else:
- return super(MSSQLCompiler, self).\
- label_select_column(select, column, asfrom)
+ return super(MSSQLCompiler, self).label_select_column(
+ select, column, asfrom
+ )
def for_update_clause(self, select):
# "FOR UPDATE" is only allowed on "DECLARE CURSOR" which
# SQLAlchemy doesn't use
- return ''
+ return ""
def order_by_clause(self, select, **kw):
order_by = self.process(select._order_by_clause, **kw)
@@ -1532,10 +1735,9 @@ class MSSQLCompiler(compiler.SQLCompiler):
else:
return ""
- def update_from_clause(self, update_stmt,
- from_table, extra_froms,
- from_hints,
- **kw):
+ def update_from_clause(
+ self, update_stmt, from_table, extra_froms, from_hints, **kw
+ ):
"""Render the UPDATE..FROM clause specific to MSSQL.
In MSSQL, if the UPDATE statement involves an alias of the table to
@@ -1543,13 +1745,12 @@ class MSSQLCompiler(compiler.SQLCompiler):
well. Otherwise, it is optional. Here, we add it regardless.
"""
- return "FROM " + ', '.join(
- t._compiler_dispatch(self, asfrom=True,
- fromhints=from_hints, **kw)
- for t in [from_table] + extra_froms)
+ return "FROM " + ", ".join(
+ t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw)
+ for t in [from_table] + extra_froms
+ )
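Compile-only sketch of the MSSQL UPDATE..FROM form; table names are illustrative:

    from sqlalchemy import Column, Integer, MetaData, Table
    from sqlalchemy.dialects import mssql

    m = MetaData()
    a = Table("a", m, Column("id", Integer), Column("x", Integer))
    b = Table("b", m, Column("id", Integer), Column("y", Integer))
    stmt = a.update().values(x=b.c.y).where(a.c.id == b.c.id)
    print(stmt.compile(dialect=mssql.dialect()))
    # UPDATE a SET x=b.y FROM a, b WHERE a.id = b.id
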
- def delete_table_clause(self, delete_stmt, from_table,
- extra_froms):
+ def delete_table_clause(self, delete_stmt, from_table, extra_froms):
"""If we have extra froms make sure we render any alias as hint."""
ashint = False
if extra_froms:
@@ -1558,20 +1759,21 @@ class MSSQLCompiler(compiler.SQLCompiler):
self, asfrom=True, iscrud=True, ashint=ashint
)
- def delete_extra_from_clause(self, delete_stmt, from_table,
- extra_froms, from_hints, **kw):
+ def delete_extra_from_clause(
+ self, delete_stmt, from_table, extra_froms, from_hints, **kw
+ ):
"""Render the DELETE .. FROM clause specific to MSSQL.
Yes, it has the FROM keyword twice.
"""
- return "FROM " + ', '.join(
- t._compiler_dispatch(self, asfrom=True,
- fromhints=from_hints, **kw)
- for t in [from_table] + extra_froms)
+ return "FROM " + ", ".join(
+ t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw)
+ for t in [from_table] + extra_froms
+ )
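And the doubled-FROM DELETE form, again with illustrative tables and the output shown roughly:

    from sqlalchemy import Column, Integer, MetaData, Table
    from sqlalchemy.dialects import mssql

    m = MetaData()
    a = Table("a", m, Column("id", Integer))
    b = Table("b", m, Column("id", Integer))
    print(a.delete().where(a.c.id == b.c.id).compile(dialect=mssql.dialect()))
    # DELETE FROM a FROM a, b WHERE a.id = b.id
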
def visit_empty_set_expr(self, type_):
- return 'SELECT 1 WHERE 1!=1'
+ return "SELECT 1 WHERE 1!=1"
class MSSQLStrictCompiler(MSSQLCompiler):
@@ -1583,20 +1785,21 @@ class MSSQLStrictCompiler(MSSQLCompiler):
binds are used.
"""
+
ansi_bind_rules = True
def visit_in_op_binary(self, binary, operator, **kw):
- kw['literal_binds'] = True
+ kw["literal_binds"] = True
return "%s IN %s" % (
self.process(binary.left, **kw),
- self.process(binary.right, **kw)
+ self.process(binary.right, **kw),
)
def visit_notin_op_binary(self, binary, operator, **kw):
- kw['literal_binds'] = True
+ kw["literal_binds"] = True
return "%s NOT IN %s" % (
self.process(binary.left, **kw),
- self.process(binary.right, **kw)
+ self.process(binary.right, **kw),
)
def render_literal_value(self, value, type_):
@@ -1615,23 +1818,28 @@ class MSSQLStrictCompiler(MSSQLCompiler):
# SQL Server wants single quotes around the date string.
return "'" + str(value) + "'"
else:
- return super(MSSQLStrictCompiler, self).\
- render_literal_value(value, type_)
+ return super(MSSQLStrictCompiler, self).render_literal_value(
+ value, type_
+ )
class MSDDLCompiler(compiler.DDLCompiler):
-
def get_column_specification(self, column, **kwargs):
colspec = (
- self.preparer.format_column(column) + " "
+ self.preparer.format_column(column)
+ + " "
+ self.dialect.type_compiler.process(
- column.type, type_expression=column)
+ column.type, type_expression=column
+ )
)
if column.nullable is not None:
- if not column.nullable or column.primary_key or \
- isinstance(column.default, sa_schema.Sequence) or \
- column.autoincrement is True:
+ if (
+ not column.nullable
+ or column.primary_key
+ or isinstance(column.default, sa_schema.Sequence)
+ or column.autoincrement is True
+ ):
colspec += " NOT NULL"
else:
colspec += " NULL"
@@ -1639,15 +1847,18 @@ class MSDDLCompiler(compiler.DDLCompiler):
if column.table is None:
raise exc.CompileError(
"mssql requires Table-bound columns "
- "in order to generate DDL")
+ "in order to generate DDL"
+ )
        # install an IDENTITY Sequence if we have either a sequence or an
        # implicit IDENTITY column
if isinstance(column.default, sa_schema.Sequence):
- if (column.default.start is not None or
- column.default.increment is not None or
- column is not column.table._autoincrement_column):
+ if (
+ column.default.start is not None
+ or column.default.increment is not None
+ or column is not column.table._autoincrement_column
+ ):
util.warn_deprecated(
"Use of Sequence with SQL Server in order to affect the "
"parameters of the IDENTITY value is deprecated, as "
@@ -1655,18 +1866,23 @@ class MSDDLCompiler(compiler.DDLCompiler):
"will correspond to an actual SQL Server "
"CREATE SEQUENCE in "
"a future release. Please use the mssql_identity_start "
- "and mssql_identity_increment parameters.")
+ "and mssql_identity_increment parameters."
+ )
if column.default.start == 0:
start = 0
else:
start = column.default.start or 1
- colspec += " IDENTITY(%s,%s)" % (start,
- column.default.increment or 1)
- elif column is column.table._autoincrement_column or \
- column.autoincrement is True:
- start = column.dialect_options['mssql']['identity_start']
- increment = column.dialect_options['mssql']['identity_increment']
+ colspec += " IDENTITY(%s,%s)" % (
+ start,
+ column.default.increment or 1,
+ )
+ elif (
+ column is column.table._autoincrement_column
+ or column.autoincrement is True
+ ):
+ start = column.dialect_options["mssql"]["identity_start"]
+ increment = column.dialect_options["mssql"]["identity_increment"]
colspec += " IDENTITY(%s,%s)" % (start, increment)
else:
default = self.get_column_default_string(column)
@@ -1684,84 +1900,88 @@ class MSDDLCompiler(compiler.DDLCompiler):
text += "UNIQUE "
# handle clustering option
- clustered = index.dialect_options['mssql']['clustered']
+ clustered = index.dialect_options["mssql"]["clustered"]
if clustered is not None:
if clustered:
text += "CLUSTERED "
else:
text += "NONCLUSTERED "
- text += "INDEX %s ON %s (%s)" \
- % (
- self._prepared_index_name(index,
- include_schema=include_schema),
- preparer.format_table(index.table),
- ', '.join(
- self.sql_compiler.process(expr,
- include_table=False,
- literal_binds=True) for
- expr in index.expressions)
- )
+ text += "INDEX %s ON %s (%s)" % (
+ self._prepared_index_name(index, include_schema=include_schema),
+ preparer.format_table(index.table),
+ ", ".join(
+ self.sql_compiler.process(
+ expr, include_table=False, literal_binds=True
+ )
+ for expr in index.expressions
+ ),
+ )
# handle other included columns
- if index.dialect_options['mssql']['include']:
- inclusions = [index.table.c[col]
- if isinstance(col, util.string_types) else col
- for col in
- index.dialect_options['mssql']['include']
- ]
+ if index.dialect_options["mssql"]["include"]:
+ inclusions = [
+ index.table.c[col]
+ if isinstance(col, util.string_types)
+ else col
+ for col in index.dialect_options["mssql"]["include"]
+ ]
- text += " INCLUDE (%s)" \
- % ', '.join([preparer.quote(c.name)
- for c in inclusions])
+ text += " INCLUDE (%s)" % ", ".join(
+ [preparer.quote(c.name) for c in inclusions]
+ )
return text
def visit_drop_index(self, drop):
return "\nDROP INDEX %s ON %s" % (
self._prepared_index_name(drop.element, include_schema=False),
- self.preparer.format_table(drop.element.table)
+ self.preparer.format_table(drop.element.table),
)
def visit_primary_key_constraint(self, constraint):
if len(constraint) == 0:
- return ''
+ return ""
text = ""
if constraint.name is not None:
- text += "CONSTRAINT %s " % \
- self.preparer.format_constraint(constraint)
+ text += "CONSTRAINT %s " % self.preparer.format_constraint(
+ constraint
+ )
text += "PRIMARY KEY "
- clustered = constraint.dialect_options['mssql']['clustered']
+ clustered = constraint.dialect_options["mssql"]["clustered"]
if clustered is not None:
if clustered:
text += "CLUSTERED "
else:
text += "NONCLUSTERED "
- text += "(%s)" % ', '.join(self.preparer.quote(c.name)
- for c in constraint)
+ text += "(%s)" % ", ".join(
+ self.preparer.quote(c.name) for c in constraint
+ )
text += self.define_constraint_deferrability(constraint)
return text
def visit_unique_constraint(self, constraint):
if len(constraint) == 0:
- return ''
+ return ""
text = ""
if constraint.name is not None:
- text += "CONSTRAINT %s " % \
- self.preparer.format_constraint(constraint)
+ text += "CONSTRAINT %s " % self.preparer.format_constraint(
+ constraint
+ )
text += "UNIQUE "
- clustered = constraint.dialect_options['mssql']['clustered']
+ clustered = constraint.dialect_options["mssql"]["clustered"]
if clustered is not None:
if clustered:
text += "CLUSTERED "
else:
text += "NONCLUSTERED "
- text += "(%s)" % ', '.join(self.preparer.quote(c.name)
- for c in constraint)
+ text += "(%s)" % ", ".join(
+ self.preparer.quote(c.name) for c in constraint
+ )
text += self.define_constraint_deferrability(constraint)
return text
@@ -1771,8 +1991,11 @@ class MSIdentifierPreparer(compiler.IdentifierPreparer):
def __init__(self, dialect):
super(MSIdentifierPreparer, self).__init__(
- dialect, initial_quote='[',
- final_quote=']', quote_case_sensitive_collations=False)
+ dialect,
+ initial_quote="[",
+ final_quote="]",
+ quote_case_sensitive_collations=False,
+ )
def _escape_identifier(self, value):
return value
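Quoting sketch: identifiers are wrapped in square brackets, and, per the no-op _escape_identifier above, nothing inside them is escaped:

    from sqlalchemy.dialects import mssql

    preparer = mssql.dialect().identifier_preparer
    print(preparer.quote("select"))   # [select]  (reserved word)
    print(preparer.quote("MyTable"))  # [MyTable] (case-sensitive name)
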
@@ -1783,7 +2006,9 @@ class MSIdentifierPreparer(compiler.IdentifierPreparer):
dbname, owner = _schema_elements(schema)
if dbname:
result = "%s.%s" % (
- self.quote(dbname, force), self.quote(owner, force))
+ self.quote(dbname, force),
+ self.quote(owner, force),
+ )
elif owner:
result = self.quote(owner, force)
else:
@@ -1794,16 +2019,37 @@ class MSIdentifierPreparer(compiler.IdentifierPreparer):
def _db_plus_owner_listing(fn):
def wrap(dialect, connection, schema=None, **kw):
dbname, owner = _owner_plus_db(dialect, schema)
- return _switch_db(dbname, connection, fn, dialect, connection,
- dbname, owner, schema, **kw)
+ return _switch_db(
+ dbname,
+ connection,
+ fn,
+ dialect,
+ connection,
+ dbname,
+ owner,
+ schema,
+ **kw
+ )
+
return update_wrapper(wrap, fn)
def _db_plus_owner(fn):
def wrap(dialect, connection, tablename, schema=None, **kw):
dbname, owner = _owner_plus_db(dialect, schema)
- return _switch_db(dbname, connection, fn, dialect, connection,
- tablename, dbname, owner, schema, **kw)
+ return _switch_db(
+ dbname,
+ connection,
+ fn,
+ dialect,
+ connection,
+ tablename,
+ dbname,
+ owner,
+ schema,
+ **kw
+ )
+
return update_wrapper(wrap, fn)
@@ -1837,9 +2083,9 @@ def _schema_elements(schema):
for token in re.split(r"(\[|\]|\.)", schema):
if not token:
continue
- if token == '[':
+ if token == "[":
bracket = True
- elif token == ']':
+ elif token == "]":
bracket = False
elif not bracket and token == ".":
push.append(symbol)
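Behavior sketch of this private helper, shown for illustration only; it splits a "database.owner" schema value, honoring bracketed names:

    from sqlalchemy.dialects.mssql.base import _schema_elements

    print(_schema_elements("mydb.dbo"))       # ('mydb', 'dbo')
    print(_schema_elements("[my db].[dbo]"))  # ('my db', 'dbo')
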
@@ -1857,7 +2103,7 @@ def _schema_elements(schema):
class MSDialect(default.DefaultDialect):
- name = 'mssql'
+ name = "mssql"
supports_default_values = True
supports_empty_insert = False
execution_ctx_cls = MSExecutionContext
@@ -1871,9 +2117,9 @@ class MSDialect(default.DefaultDialect):
sqltypes.Time: TIME,
}
- engine_config_types = default.DefaultDialect.engine_config_types.union([
- ('legacy_schema_aliasing', util.asbool),
- ])
+ engine_config_types = default.DefaultDialect.engine_config_types.union(
+ [("legacy_schema_aliasing", util.asbool)]
+ )
ischema_names = ischema_names
@@ -1890,36 +2136,30 @@ class MSDialect(default.DefaultDialect):
preparer = MSIdentifierPreparer
construct_arguments = [
- (sa_schema.PrimaryKeyConstraint, {
- "clustered": None
- }),
- (sa_schema.UniqueConstraint, {
- "clustered": None
- }),
- (sa_schema.Index, {
- "clustered": None,
- "include": None
- }),
- (sa_schema.Column, {
- "identity_start": 1,
- "identity_increment": 1
- })
+ (sa_schema.PrimaryKeyConstraint, {"clustered": None}),
+ (sa_schema.UniqueConstraint, {"clustered": None}),
+ (sa_schema.Index, {"clustered": None, "include": None}),
+ (sa_schema.Column, {"identity_start": 1, "identity_increment": 1}),
]
- def __init__(self,
- query_timeout=None,
- use_scope_identity=True,
- max_identifier_length=None,
- schema_name="dbo",
- isolation_level=None,
- deprecate_large_types=None,
- legacy_schema_aliasing=False, **opts):
+ def __init__(
+ self,
+ query_timeout=None,
+ use_scope_identity=True,
+ max_identifier_length=None,
+ schema_name="dbo",
+ isolation_level=None,
+ deprecate_large_types=None,
+ legacy_schema_aliasing=False,
+ **opts
+ ):
self.query_timeout = int(query_timeout or 0)
self.schema_name = schema_name
self.use_scope_identity = use_scope_identity
- self.max_identifier_length = int(max_identifier_length or 0) or \
- self.max_identifier_length
+ self.max_identifier_length = (
+ int(max_identifier_length or 0) or self.max_identifier_length
+ )
self.deprecate_large_types = deprecate_large_types
self.legacy_schema_aliasing = legacy_schema_aliasing
@@ -1936,27 +2176,33 @@ class MSDialect(default.DefaultDialect):
# SQL Server does not support RELEASE SAVEPOINT
pass
- _isolation_lookup = set(['SERIALIZABLE', 'READ UNCOMMITTED',
- 'READ COMMITTED', 'REPEATABLE READ',
- 'SNAPSHOT'])
+ _isolation_lookup = set(
+ [
+ "SERIALIZABLE",
+ "READ UNCOMMITTED",
+ "READ COMMITTED",
+ "REPEATABLE READ",
+ "SNAPSHOT",
+ ]
+ )
def set_isolation_level(self, connection, level):
- level = level.replace('_', ' ')
+ level = level.replace("_", " ")
if level not in self._isolation_lookup:
raise exc.ArgumentError(
"Invalid value '%s' for isolation_level. "
- "Valid isolation levels for %s are %s" %
- (level, self.name, ", ".join(self._isolation_lookup))
+ "Valid isolation levels for %s are %s"
+ % (level, self.name, ", ".join(self._isolation_lookup))
)
cursor = connection.cursor()
- cursor.execute(
- "SET TRANSACTION ISOLATION LEVEL %s" % level)
+ cursor.execute("SET TRANSACTION ISOLATION LEVEL %s" % level)
cursor.close()
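Usage sketch, with a hypothetical URL; pyodbc and a reachable server are assumed. The configured level is applied per-connection via the SET statement above:

    from sqlalchemy import create_engine

    engine = create_engine(
        "mssql+pyodbc://user:pw@mydsn",
        isolation_level="SNAPSHOT",
    )
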
def get_isolation_level(self, connection):
if self.server_version_info < MS_2005_VERSION:
raise NotImplementedError(
- "Can't fetch isolation level prior to SQL Server 2005")
+ "Can't fetch isolation level prior to SQL Server 2005"
+ )
last_error = None
@@ -1964,7 +2210,8 @@ class MSDialect(default.DefaultDialect):
for view in views:
cursor = connection.cursor()
try:
- cursor.execute("""
+ cursor.execute(
+ """
SELECT CASE transaction_isolation_level
WHEN 0 THEN NULL
WHEN 1 THEN 'READ UNCOMMITTED'
@@ -1974,7 +2221,9 @@ class MSDialect(default.DefaultDialect):
WHEN 5 THEN 'SNAPSHOT' END AS TRANSACTION_ISOLATION_LEVEL
FROM %s
where session_id = @@SPID
- """ % view)
+ """
+ % view
+ )
val = cursor.fetchone()[0]
except self.dbapi.Error as err:
# Python3 scoping rules
@@ -1987,7 +2236,8 @@ class MSDialect(default.DefaultDialect):
else:
util.warn(
"Could not fetch transaction isolation level, "
- "tried views: %s; final error was: %s" % (views, last_error))
+ "tried views: %s; final error was: %s" % (views, last_error)
+ )
raise NotImplementedError(
"Can't fetch isolation level on this particular "
@@ -2000,8 +2250,10 @@ class MSDialect(default.DefaultDialect):
def on_connect(self):
if self.isolation_level is not None:
+
def connect(conn):
self.set_isolation_level(conn, self.isolation_level)
+
return connect
else:
return None
@@ -2010,16 +2262,20 @@ class MSDialect(default.DefaultDialect):
if self.server_version_info[0] not in list(range(8, 17)):
util.warn(
"Unrecognized server version info '%s'. Some SQL Server "
- "features may not function properly." %
- ".".join(str(x) for x in self.server_version_info))
- if self.server_version_info >= MS_2005_VERSION and \
- 'implicit_returning' not in self.__dict__:
+ "features may not function properly."
+ % ".".join(str(x) for x in self.server_version_info)
+ )
+ if (
+ self.server_version_info >= MS_2005_VERSION
+ and "implicit_returning" not in self.__dict__
+ ):
self.implicit_returning = True
if self.server_version_info >= MS_2008_VERSION:
self.supports_multivalues_insert = True
if self.deprecate_large_types is None:
- self.deprecate_large_types = \
+ self.deprecate_large_types = (
self.server_version_info >= MS_2012_VERSION
+ )
def _get_default_schema_name(self, connection):
if self.server_version_info < MS_2005_VERSION:
@@ -2039,17 +2295,19 @@ class MSDialect(default.DefaultDialect):
whereclause = columns.c.table_name == tablename
if owner:
- whereclause = sql.and_(whereclause,
- columns.c.table_schema == owner)
+ whereclause = sql.and_(
+ whereclause, columns.c.table_schema == owner
+ )
s = sql.select([columns], whereclause)
c = connection.execute(s)
return c.first() is not None
@reflection.cache
def get_schema_names(self, connection, **kw):
- s = sql.select([ischema.schemata.c.schema_name],
- order_by=[ischema.schemata.c.schema_name]
- )
+ s = sql.select(
+ [ischema.schemata.c.schema_name],
+ order_by=[ischema.schemata.c.schema_name],
+ )
schema_names = [r[0] for r in connection.execute(s)]
return schema_names
@@ -2057,12 +2315,13 @@ class MSDialect(default.DefaultDialect):
@_db_plus_owner_listing
def get_table_names(self, connection, dbname, owner, schema, **kw):
tables = ischema.tables
- s = sql.select([tables.c.table_name],
- sql.and_(
- tables.c.table_schema == owner,
- tables.c.table_type == 'BASE TABLE'
- ),
- order_by=[tables.c.table_name]
+ s = sql.select(
+ [tables.c.table_name],
+ sql.and_(
+ tables.c.table_schema == owner,
+ tables.c.table_type == "BASE TABLE",
+ ),
+ order_by=[tables.c.table_name],
)
table_names = [r[0] for r in connection.execute(s)]
return table_names
@@ -2071,12 +2330,12 @@ class MSDialect(default.DefaultDialect):
@_db_plus_owner_listing
def get_view_names(self, connection, dbname, owner, schema, **kw):
tables = ischema.tables
- s = sql.select([tables.c.table_name],
- sql.and_(
- tables.c.table_schema == owner,
- tables.c.table_type == 'VIEW'
- ),
- order_by=[tables.c.table_name]
+ s = sql.select(
+ [tables.c.table_name],
+ sql.and_(
+ tables.c.table_schema == owner, tables.c.table_type == "VIEW"
+ ),
+ order_by=[tables.c.table_name],
)
view_names = [r[0] for r in connection.execute(s)]
return view_names
@@ -2090,30 +2349,33 @@ class MSDialect(default.DefaultDialect):
return []
rp = connection.execute(
- sql.text("select ind.index_id, ind.is_unique, ind.name "
- "from sys.indexes as ind join sys.tables as tab on "
- "ind.object_id=tab.object_id "
- "join sys.schemas as sch on sch.schema_id=tab.schema_id "
- "where tab.name = :tabname "
- "and sch.name=:schname "
- "and ind.is_primary_key=0 and ind.type != 0",
- bindparams=[
- sql.bindparam('tabname', tablename,
- sqltypes.String(convert_unicode=True)),
- sql.bindparam('schname', owner,
- sqltypes.String(convert_unicode=True))
- ],
- typemap={
- 'name': sqltypes.Unicode()
- }
- )
+ sql.text(
+ "select ind.index_id, ind.is_unique, ind.name "
+ "from sys.indexes as ind join sys.tables as tab on "
+ "ind.object_id=tab.object_id "
+ "join sys.schemas as sch on sch.schema_id=tab.schema_id "
+ "where tab.name = :tabname "
+ "and sch.name=:schname "
+ "and ind.is_primary_key=0 and ind.type != 0",
+ bindparams=[
+ sql.bindparam(
+ "tabname",
+ tablename,
+ sqltypes.String(convert_unicode=True),
+ ),
+ sql.bindparam(
+ "schname", owner, sqltypes.String(convert_unicode=True)
+ ),
+ ],
+ typemap={"name": sqltypes.Unicode()},
+ )
)
indexes = {}
for row in rp:
- indexes[row['index_id']] = {
- 'name': row['name'],
- 'unique': row['is_unique'] == 1,
- 'column_names': []
+ indexes[row["index_id"]] = {
+ "name": row["name"],
+ "unique": row["is_unique"] == 1,
+ "column_names": [],
}
rp = connection.execute(
sql.text(
@@ -2127,24 +2389,29 @@ class MSDialect(default.DefaultDialect):
"where tab.name=:tabname "
"and sch.name=:schname",
bindparams=[
- sql.bindparam('tabname', tablename,
- sqltypes.String(convert_unicode=True)),
- sql.bindparam('schname', owner,
- sqltypes.String(convert_unicode=True))
+ sql.bindparam(
+ "tabname",
+ tablename,
+ sqltypes.String(convert_unicode=True),
+ ),
+ sql.bindparam(
+ "schname", owner, sqltypes.String(convert_unicode=True)
+ ),
],
- typemap={'name': sqltypes.Unicode()}
- ),
+ typemap={"name": sqltypes.Unicode()},
+ )
)
for row in rp:
- if row['index_id'] in indexes:
- indexes[row['index_id']]['column_names'].append(row['name'])
+ if row["index_id"] in indexes:
+ indexes[row["index_id"]]["column_names"].append(row["name"])
return list(indexes.values())
@reflection.cache
@_db_plus_owner
- def get_view_definition(self, connection, viewname,
- dbname, owner, schema, **kw):
+ def get_view_definition(
+ self, connection, viewname, dbname, owner, schema, **kw
+ ):
rp = connection.execute(
sql.text(
"select definition from sys.sql_modules as mod, "
@@ -2155,11 +2422,15 @@ class MSDialect(default.DefaultDialect):
"views.schema_id=sch.schema_id and "
"views.name=:viewname and sch.name=:schname",
bindparams=[
- sql.bindparam('viewname', viewname,
- sqltypes.String(convert_unicode=True)),
- sql.bindparam('schname', owner,
- sqltypes.String(convert_unicode=True))
- ]
+ sql.bindparam(
+ "viewname",
+ viewname,
+ sqltypes.String(convert_unicode=True),
+ ),
+ sql.bindparam(
+ "schname", owner, sqltypes.String(convert_unicode=True)
+ ),
+ ],
)
)
@@ -2173,12 +2444,15 @@ class MSDialect(default.DefaultDialect):
# Get base columns
columns = ischema.columns
if owner:
- whereclause = sql.and_(columns.c.table_name == tablename,
- columns.c.table_schema == owner)
+ whereclause = sql.and_(
+ columns.c.table_name == tablename,
+ columns.c.table_schema == owner,
+ )
else:
whereclause = columns.c.table_name == tablename
- s = sql.select([columns], whereclause,
- order_by=[columns.c.ordinal_position])
+ s = sql.select(
+ [columns], whereclause, order_by=[columns.c.ordinal_position]
+ )
c = connection.execute(s)
cols = []
@@ -2186,57 +2460,76 @@ class MSDialect(default.DefaultDialect):
row = c.fetchone()
if row is None:
break
- (name, type, nullable, charlen,
- numericprec, numericscale, default, collation) = (
+ (
+ name,
+ type,
+ nullable,
+ charlen,
+ numericprec,
+ numericscale,
+ default,
+ collation,
+ ) = (
row[columns.c.column_name],
row[columns.c.data_type],
- row[columns.c.is_nullable] == 'YES',
+ row[columns.c.is_nullable] == "YES",
row[columns.c.character_maximum_length],
row[columns.c.numeric_precision],
row[columns.c.numeric_scale],
row[columns.c.column_default],
- row[columns.c.collation_name]
+ row[columns.c.collation_name],
)
coltype = self.ischema_names.get(type, None)
kwargs = {}
- if coltype in (MSString, MSChar, MSNVarchar, MSNChar, MSText,
- MSNText, MSBinary, MSVarBinary,
- sqltypes.LargeBinary):
+ if coltype in (
+ MSString,
+ MSChar,
+ MSNVarchar,
+ MSNChar,
+ MSText,
+ MSNText,
+ MSBinary,
+ MSVarBinary,
+ sqltypes.LargeBinary,
+ ):
if charlen == -1:
charlen = None
- kwargs['length'] = charlen
+ kwargs["length"] = charlen
if collation:
- kwargs['collation'] = collation
+ kwargs["collation"] = collation
if coltype is None:
util.warn(
- "Did not recognize type '%s' of column '%s'" %
- (type, name))
+ "Did not recognize type '%s' of column '%s'" % (type, name)
+ )
coltype = sqltypes.NULLTYPE
else:
- if issubclass(coltype, sqltypes.Numeric) and \
- coltype is not MSReal:
- kwargs['scale'] = numericscale
- kwargs['precision'] = numericprec
+ if (
+ issubclass(coltype, sqltypes.Numeric)
+ and coltype is not MSReal
+ ):
+ kwargs["scale"] = numericscale
+ kwargs["precision"] = numericprec
coltype = coltype(**kwargs)
cdict = {
- 'name': name,
- 'type': coltype,
- 'nullable': nullable,
- 'default': default,
- 'autoincrement': False,
+ "name": name,
+ "type": coltype,
+ "nullable": nullable,
+ "default": default,
+ "autoincrement": False,
}
cols.append(cdict)
# autoincrement and identity
colmap = {}
for col in cols:
- colmap[col['name']] = col
+ colmap[col["name"]] = col
# We also run an sp_columns to check for identity columns:
- cursor = connection.execute("sp_columns @table_name = '%s', "
- "@table_owner = '%s'"
- % (tablename, owner))
+ cursor = connection.execute(
+ "sp_columns @table_name = '%s', "
+ "@table_owner = '%s'" % (tablename, owner)
+ )
ic = None
while True:
row = cursor.fetchone()
@@ -2245,10 +2538,10 @@ class MSDialect(default.DefaultDialect):
(col_name, type_name) = row[3], row[5]
if type_name.endswith("identity") and col_name in colmap:
ic = col_name
- colmap[col_name]['autoincrement'] = True
- colmap[col_name]['dialect_options'] = {
- 'mssql_identity_start': 1,
- 'mssql_identity_increment': 1
+ colmap[col_name]["autoincrement"] = True
+ colmap[col_name]["dialect_options"] = {
+ "mssql_identity_start": 1,
+ "mssql_identity_increment": 1,
}
break
cursor.close()
@@ -2262,64 +2555,74 @@ class MSDialect(default.DefaultDialect):
row = cursor.first()
if row is not None and row[0] is not None:
- colmap[ic]['dialect_options'].update({
- 'mssql_identity_start': int(row[0]),
- 'mssql_identity_increment': int(row[1])
- })
+ colmap[ic]["dialect_options"].update(
+ {
+ "mssql_identity_start": int(row[0]),
+ "mssql_identity_increment": int(row[1]),
+ }
+ )
return cols
@reflection.cache
@_db_plus_owner
- def get_pk_constraint(self, connection, tablename,
- dbname, owner, schema, **kw):
+ def get_pk_constraint(
+ self, connection, tablename, dbname, owner, schema, **kw
+ ):
pkeys = []
TC = ischema.constraints
- C = ischema.key_constraints.alias('C')
+ C = ischema.key_constraints.alias("C")
# Primary key constraints
- s = sql.select([C.c.column_name,
- TC.c.constraint_type,
- C.c.constraint_name],
- sql.and_(TC.c.constraint_name == C.c.constraint_name,
- TC.c.table_schema == C.c.table_schema,
- C.c.table_name == tablename,
- C.c.table_schema == owner)
- )
+ s = sql.select(
+ [C.c.column_name, TC.c.constraint_type, C.c.constraint_name],
+ sql.and_(
+ TC.c.constraint_name == C.c.constraint_name,
+ TC.c.table_schema == C.c.table_schema,
+ C.c.table_name == tablename,
+ C.c.table_schema == owner,
+ ),
+ )
c = connection.execute(s)
constraint_name = None
for row in c:
- if 'PRIMARY' in row[TC.c.constraint_type.name]:
+ if "PRIMARY" in row[TC.c.constraint_type.name]:
pkeys.append(row[0])
if constraint_name is None:
constraint_name = row[C.c.constraint_name.name]
- return {'constrained_columns': pkeys, 'name': constraint_name}
+ return {"constrained_columns": pkeys, "name": constraint_name}
@reflection.cache
@_db_plus_owner
- def get_foreign_keys(self, connection, tablename,
- dbname, owner, schema, **kw):
+ def get_foreign_keys(
+ self, connection, tablename, dbname, owner, schema, **kw
+ ):
RR = ischema.ref_constraints
- C = ischema.key_constraints.alias('C')
- R = ischema.key_constraints.alias('R')
+ C = ischema.key_constraints.alias("C")
+ R = ischema.key_constraints.alias("R")
# Foreign key constraints
- s = sql.select([C.c.column_name,
- R.c.table_schema, R.c.table_name, R.c.column_name,
- RR.c.constraint_name, RR.c.match_option,
- RR.c.update_rule,
- RR.c.delete_rule],
- sql.and_(C.c.table_name == tablename,
- C.c.table_schema == owner,
- RR.c.constraint_schema == C.c.table_schema,
- C.c.constraint_name == RR.c.constraint_name,
- R.c.constraint_name ==
- RR.c.unique_constraint_name,
- R.c.constraint_schema ==
- RR.c.unique_constraint_schema,
- C.c.ordinal_position == R.c.ordinal_position
- ),
- order_by=[RR.c.constraint_name, R.c.ordinal_position]
- )
+ s = sql.select(
+ [
+ C.c.column_name,
+ R.c.table_schema,
+ R.c.table_name,
+ R.c.column_name,
+ RR.c.constraint_name,
+ RR.c.match_option,
+ RR.c.update_rule,
+ RR.c.delete_rule,
+ ],
+ sql.and_(
+ C.c.table_name == tablename,
+ C.c.table_schema == owner,
+ RR.c.constraint_schema == C.c.table_schema,
+ C.c.constraint_name == RR.c.constraint_name,
+ R.c.constraint_name == RR.c.unique_constraint_name,
+ R.c.constraint_schema == RR.c.unique_constraint_schema,
+ C.c.ordinal_position == R.c.ordinal_position,
+ ),
+ order_by=[RR.c.constraint_name, R.c.ordinal_position],
+ )
# group rows by constraint ID, to handle multi-column FKs
fkeys = []
@@ -2327,11 +2630,11 @@ class MSDialect(default.DefaultDialect):
def fkey_rec():
return {
- 'name': None,
- 'constrained_columns': [],
- 'referred_schema': None,
- 'referred_table': None,
- 'referred_columns': []
+ "name": None,
+ "constrained_columns": [],
+ "referred_schema": None,
+ "referred_table": None,
+ "referred_columns": [],
}
fkeys = util.defaultdict(fkey_rec)
@@ -2340,17 +2643,18 @@ class MSDialect(default.DefaultDialect):
scol, rschema, rtbl, rcol, rfknm, fkmatch, fkuprule, fkdelrule = r
rec = fkeys[rfknm]
- rec['name'] = rfknm
- if not rec['referred_table']:
- rec['referred_table'] = rtbl
+ rec["name"] = rfknm
+ if not rec["referred_table"]:
+ rec["referred_table"] = rtbl
if schema is not None or owner != rschema:
if dbname:
rschema = dbname + "." + rschema
- rec['referred_schema'] = rschema
+ rec["referred_schema"] = rschema
- local_cols, remote_cols = \
- rec['constrained_columns'],\
- rec['referred_columns']
+ local_cols, remote_cols = (
+ rec["constrained_columns"],
+ rec["referred_columns"],
+ )
local_cols.append(scol)
remote_cols.append(rcol)
diff --git a/lib/sqlalchemy/dialects/mssql/information_schema.py b/lib/sqlalchemy/dialects/mssql/information_schema.py
index 3682fae48..c4ea8ab0c 100644
--- a/lib/sqlalchemy/dialects/mssql/information_schema.py
+++ b/lib/sqlalchemy/dialects/mssql/information_schema.py
@@ -38,102 +38,122 @@ class _cast_on_2005(expression.ColumnElement):
@compiles(_cast_on_2005)
def _compile(element, compiler, **kw):
from . import base
- if compiler.dialect.server_version_info is None or \
- compiler.dialect.server_version_info < base.MS_2005_VERSION:
+
+ if (
+ compiler.dialect.server_version_info is None
+ or compiler.dialect.server_version_info < base.MS_2005_VERSION
+ ):
return compiler.process(element.bindvalue, **kw)
else:
return compiler.process(cast(element.bindvalue, Unicode), **kw)
-schemata = Table("SCHEMATA", ischema,
- Column("CATALOG_NAME", CoerceUnicode, key="catalog_name"),
- Column("SCHEMA_NAME", CoerceUnicode, key="schema_name"),
- Column("SCHEMA_OWNER", CoerceUnicode, key="schema_owner"),
- schema="INFORMATION_SCHEMA")
-
-tables = Table("TABLES", ischema,
- Column("TABLE_CATALOG", CoerceUnicode, key="table_catalog"),
- Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
- Column("TABLE_NAME", CoerceUnicode, key="table_name"),
- Column(
- "TABLE_TYPE", String(convert_unicode=True),
- key="table_type"),
- schema="INFORMATION_SCHEMA")
-
-columns = Table("COLUMNS", ischema,
- Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
- Column("TABLE_NAME", CoerceUnicode, key="table_name"),
- Column("COLUMN_NAME", CoerceUnicode, key="column_name"),
- Column("IS_NULLABLE", Integer, key="is_nullable"),
- Column("DATA_TYPE", String, key="data_type"),
- Column("ORDINAL_POSITION", Integer, key="ordinal_position"),
- Column("CHARACTER_MAXIMUM_LENGTH", Integer,
- key="character_maximum_length"),
- Column("NUMERIC_PRECISION", Integer, key="numeric_precision"),
- Column("NUMERIC_SCALE", Integer, key="numeric_scale"),
- Column("COLUMN_DEFAULT", Integer, key="column_default"),
- Column("COLLATION_NAME", String, key="collation_name"),
- schema="INFORMATION_SCHEMA")
-
-constraints = Table("TABLE_CONSTRAINTS", ischema,
- Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
- Column("TABLE_NAME", CoerceUnicode, key="table_name"),
- Column("CONSTRAINT_NAME", CoerceUnicode,
- key="constraint_name"),
- Column("CONSTRAINT_TYPE", String(
- convert_unicode=True), key="constraint_type"),
- schema="INFORMATION_SCHEMA")
-
-column_constraints = Table("CONSTRAINT_COLUMN_USAGE", ischema,
- Column("TABLE_SCHEMA", CoerceUnicode,
- key="table_schema"),
- Column("TABLE_NAME", CoerceUnicode,
- key="table_name"),
- Column("COLUMN_NAME", CoerceUnicode,
- key="column_name"),
- Column("CONSTRAINT_NAME", CoerceUnicode,
- key="constraint_name"),
- schema="INFORMATION_SCHEMA")
-
-key_constraints = Table("KEY_COLUMN_USAGE", ischema,
- Column("TABLE_SCHEMA", CoerceUnicode,
- key="table_schema"),
- Column("TABLE_NAME", CoerceUnicode,
- key="table_name"),
- Column("COLUMN_NAME", CoerceUnicode,
- key="column_name"),
- Column("CONSTRAINT_NAME", CoerceUnicode,
- key="constraint_name"),
- Column("CONSTRAINT_SCHEMA", CoerceUnicode,
- key="constraint_schema"),
- Column("ORDINAL_POSITION", Integer,
- key="ordinal_position"),
- schema="INFORMATION_SCHEMA")
-
-ref_constraints = Table("REFERENTIAL_CONSTRAINTS", ischema,
- Column("CONSTRAINT_CATALOG", CoerceUnicode,
- key="constraint_catalog"),
- Column("CONSTRAINT_SCHEMA", CoerceUnicode,
- key="constraint_schema"),
- Column("CONSTRAINT_NAME", CoerceUnicode,
- key="constraint_name"),
- # TODO: is CATLOG misspelled ?
- Column("UNIQUE_CONSTRAINT_CATLOG", CoerceUnicode,
- key="unique_constraint_catalog"),
-
- Column("UNIQUE_CONSTRAINT_SCHEMA", CoerceUnicode,
- key="unique_constraint_schema"),
- Column("UNIQUE_CONSTRAINT_NAME", CoerceUnicode,
- key="unique_constraint_name"),
- Column("MATCH_OPTION", String, key="match_option"),
- Column("UPDATE_RULE", String, key="update_rule"),
- Column("DELETE_RULE", String, key="delete_rule"),
- schema="INFORMATION_SCHEMA")
-
-views = Table("VIEWS", ischema,
- Column("TABLE_CATALOG", CoerceUnicode, key="table_catalog"),
- Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
- Column("TABLE_NAME", CoerceUnicode, key="table_name"),
- Column("VIEW_DEFINITION", CoerceUnicode, key="view_definition"),
- Column("CHECK_OPTION", String, key="check_option"),
- Column("IS_UPDATABLE", String, key="is_updatable"),
- schema="INFORMATION_SCHEMA")
+
+schemata = Table(
+ "SCHEMATA",
+ ischema,
+ Column("CATALOG_NAME", CoerceUnicode, key="catalog_name"),
+ Column("SCHEMA_NAME", CoerceUnicode, key="schema_name"),
+ Column("SCHEMA_OWNER", CoerceUnicode, key="schema_owner"),
+ schema="INFORMATION_SCHEMA",
+)
+
+tables = Table(
+ "TABLES",
+ ischema,
+ Column("TABLE_CATALOG", CoerceUnicode, key="table_catalog"),
+ Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
+ Column("TABLE_NAME", CoerceUnicode, key="table_name"),
+ Column("TABLE_TYPE", String(convert_unicode=True), key="table_type"),
+ schema="INFORMATION_SCHEMA",
+)
+
+columns = Table(
+ "COLUMNS",
+ ischema,
+ Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
+ Column("TABLE_NAME", CoerceUnicode, key="table_name"),
+ Column("COLUMN_NAME", CoerceUnicode, key="column_name"),
+ Column("IS_NULLABLE", Integer, key="is_nullable"),
+ Column("DATA_TYPE", String, key="data_type"),
+ Column("ORDINAL_POSITION", Integer, key="ordinal_position"),
+ Column(
+ "CHARACTER_MAXIMUM_LENGTH", Integer, key="character_maximum_length"
+ ),
+ Column("NUMERIC_PRECISION", Integer, key="numeric_precision"),
+ Column("NUMERIC_SCALE", Integer, key="numeric_scale"),
+ Column("COLUMN_DEFAULT", Integer, key="column_default"),
+ Column("COLLATION_NAME", String, key="collation_name"),
+ schema="INFORMATION_SCHEMA",
+)
+
+constraints = Table(
+ "TABLE_CONSTRAINTS",
+ ischema,
+ Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
+ Column("TABLE_NAME", CoerceUnicode, key="table_name"),
+ Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"),
+ Column(
+ "CONSTRAINT_TYPE", String(convert_unicode=True), key="constraint_type"
+ ),
+ schema="INFORMATION_SCHEMA",
+)
+
+column_constraints = Table(
+ "CONSTRAINT_COLUMN_USAGE",
+ ischema,
+ Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
+ Column("TABLE_NAME", CoerceUnicode, key="table_name"),
+ Column("COLUMN_NAME", CoerceUnicode, key="column_name"),
+ Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"),
+ schema="INFORMATION_SCHEMA",
+)
+
+key_constraints = Table(
+ "KEY_COLUMN_USAGE",
+ ischema,
+ Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
+ Column("TABLE_NAME", CoerceUnicode, key="table_name"),
+ Column("COLUMN_NAME", CoerceUnicode, key="column_name"),
+ Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"),
+ Column("CONSTRAINT_SCHEMA", CoerceUnicode, key="constraint_schema"),
+ Column("ORDINAL_POSITION", Integer, key="ordinal_position"),
+ schema="INFORMATION_SCHEMA",
+)
+
+ref_constraints = Table(
+ "REFERENTIAL_CONSTRAINTS",
+ ischema,
+ Column("CONSTRAINT_CATALOG", CoerceUnicode, key="constraint_catalog"),
+ Column("CONSTRAINT_SCHEMA", CoerceUnicode, key="constraint_schema"),
+ Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"),
+ # TODO: is CATLOG misspelled ?
+ Column(
+ "UNIQUE_CONSTRAINT_CATLOG",
+ CoerceUnicode,
+ key="unique_constraint_catalog",
+ ),
+ Column(
+ "UNIQUE_CONSTRAINT_SCHEMA",
+ CoerceUnicode,
+ key="unique_constraint_schema",
+ ),
+ Column(
+ "UNIQUE_CONSTRAINT_NAME", CoerceUnicode, key="unique_constraint_name"
+ ),
+ Column("MATCH_OPTION", String, key="match_option"),
+ Column("UPDATE_RULE", String, key="update_rule"),
+ Column("DELETE_RULE", String, key="delete_rule"),
+ schema="INFORMATION_SCHEMA",
+)
+
+views = Table(
+ "VIEWS",
+ ischema,
+ Column("TABLE_CATALOG", CoerceUnicode, key="table_catalog"),
+ Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
+ Column("TABLE_NAME", CoerceUnicode, key="table_name"),
+ Column("VIEW_DEFINITION", CoerceUnicode, key="view_definition"),
+ Column("CHECK_OPTION", String, key="check_option"),
+ Column("IS_UPDATABLE", String, key="is_updatable"),
+ schema="INFORMATION_SCHEMA",
+)
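
The Table objects above give the MSSQL dialect a typed, selectable view of the INFORMATION_SCHEMA catalog; the ``key=`` arguments remap the uppercase catalog column names to lowercase attribute names on the Python side. A minimal sketch of a reflection-style query composed against the ``columns`` Table defined above (the target table and schema names are hypothetical):

    from sqlalchemy import select

    # column metadata for a hypothetical table "user_account" in "dbo";
    # Python-side access goes through the lowercase .key names while the
    # rendered SQL uses INFORMATION_SCHEMA.COLUMNS' uppercase names
    s = (
        select([columns.c.column_name, columns.c.data_type])
        .where(columns.c.table_name == "user_account")
        .where(columns.c.table_schema == "dbo")
        .order_by(columns.c.ordinal_position)
    )
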
diff --git a/lib/sqlalchemy/dialects/mssql/mxodbc.py b/lib/sqlalchemy/dialects/mssql/mxodbc.py
index 8983a3b60..3b9ea2707 100644
--- a/lib/sqlalchemy/dialects/mssql/mxodbc.py
+++ b/lib/sqlalchemy/dialects/mssql/mxodbc.py
@@ -46,10 +46,14 @@ of ``False`` will unconditionally use string-escaped parameters.
from ... import types as sqltypes
from ...connectors.mxodbc import MxODBCConnector
from .pyodbc import MSExecutionContext_pyodbc, _MSNumeric_pyodbc
-from .base import (MSDialect,
- MSSQLStrictCompiler,
- VARBINARY,
- _MSDateTime, _MSDate, _MSTime)
+from .base import (
+ MSDialect,
+ MSSQLStrictCompiler,
+ VARBINARY,
+ _MSDateTime,
+ _MSDate,
+ _MSTime,
+)
class _MSNumeric_mxodbc(_MSNumeric_pyodbc):
@@ -64,6 +68,7 @@ class _MSDate_mxodbc(_MSDate):
return "%s-%s-%s" % (value.year, value.month, value.day)
else:
return None
+
return process
@@ -74,6 +79,7 @@ class _MSTime_mxodbc(_MSTime):
return "%s:%s:%s" % (value.hour, value.minute, value.second)
else:
return None
+
return process
@@ -98,6 +104,7 @@ class _VARBINARY_mxodbc(VARBINARY):
else:
# should pull from mx.ODBC.Manager.BinaryNull
return dialect.dbapi.BinaryNull
+
return process
@@ -107,6 +114,7 @@ class MSExecutionContext_mxodbc(MSExecutionContext_pyodbc):
SELECT SCOPE_IDENTITY in cases where OUTPUT clause
does not work (tables with insert triggers).
"""
+
# todo - investigate whether the pyodbc execution context
# is really only being used in cases where OUTPUT
# won't work.
@@ -136,4 +144,5 @@ class MSDialect_mxodbc(MxODBCConnector, MSDialect):
super(MSDialect_mxodbc, self).__init__(**params)
self.description_encoding = description_encoding
+
dialect = MSDialect_mxodbc
diff --git a/lib/sqlalchemy/dialects/mssql/pymssql.py b/lib/sqlalchemy/dialects/mssql/pymssql.py
index 8589c8b06..847c00329 100644
--- a/lib/sqlalchemy/dialects/mssql/pymssql.py
+++ b/lib/sqlalchemy/dialects/mssql/pymssql.py
@@ -35,7 +35,6 @@ class _MSNumeric_pymssql(sqltypes.Numeric):
class MSIdentifierPreparer_pymssql(MSIdentifierPreparer):
-
def __init__(self, dialect):
super(MSIdentifierPreparer_pymssql, self).__init__(dialect)
# pymssql has the very unusual behavior that it uses pyformat
@@ -45,47 +44,45 @@ class MSIdentifierPreparer_pymssql(MSIdentifierPreparer):
class MSDialect_pymssql(MSDialect):
supports_native_decimal = True
- driver = 'pymssql'
+ driver = "pymssql"
preparer = MSIdentifierPreparer_pymssql
colspecs = util.update_copy(
MSDialect.colspecs,
- {
- sqltypes.Numeric: _MSNumeric_pymssql,
- sqltypes.Float: sqltypes.Float,
- }
+ {sqltypes.Numeric: _MSNumeric_pymssql, sqltypes.Float: sqltypes.Float},
)
@classmethod
def dbapi(cls):
- module = __import__('pymssql')
+ module = __import__("pymssql")
        # pymssql < 2.1.1 doesn't have a Binary method; we use string
client_ver = tuple(int(x) for x in module.__version__.split("."))
if client_ver < (2, 1, 1):
# TODO: monkeypatching here is less than ideal
- module.Binary = lambda x: x if hasattr(x, 'decode') else str(x)
+ module.Binary = lambda x: x if hasattr(x, "decode") else str(x)
- if client_ver < (1, ):
- util.warn("The pymssql dialect expects at least "
- "the 1.0 series of the pymssql DBAPI.")
+ if client_ver < (1,):
+ util.warn(
+ "The pymssql dialect expects at least "
+ "the 1.0 series of the pymssql DBAPI."
+ )
return module
def _get_server_version_info(self, connection):
vers = connection.scalar("select @@version")
- m = re.match(
- r"Microsoft .*? - (\d+).(\d+).(\d+).(\d+)", vers)
+        m = re.match(r"Microsoft .*? - (\d+)\.(\d+)\.(\d+)\.(\d+)", vers)
if m:
return tuple(int(x) for x in m.group(1, 2, 3, 4))
else:
return None
def create_connect_args(self, url):
- opts = url.translate_connect_args(username='user')
+ opts = url.translate_connect_args(username="user")
opts.update(url.query)
- port = opts.pop('port', None)
- if port and 'host' in opts:
- opts['host'] = "%s:%s" % (opts['host'], port)
+ port = opts.pop("port", None)
+ if port and "host" in opts:
+ opts["host"] = "%s:%s" % (opts["host"], port)
return [[], opts]
def is_disconnect(self, e, connection, cursor):
@@ -105,12 +102,13 @@ class MSDialect_pymssql(MSDialect):
return False
def set_isolation_level(self, connection, level):
- if level == 'AUTOCOMMIT':
+ if level == "AUTOCOMMIT":
connection.autocommit(True)
else:
connection.autocommit(False)
- super(MSDialect_pymssql, self).set_isolation_level(connection,
- level)
+ super(MSDialect_pymssql, self).set_isolation_level(
+ connection, level
+ )
dialect = MSDialect_pymssql
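
pymssql takes the TCP port as part of the host string rather than as a separate argument, which is why ``create_connect_args()`` above folds it in. A sketch of the translation for a typical URL (credentials and host are placeholders):

    from sqlalchemy.engine import url as sa_url

    u = sa_url.make_url("mssql+pymssql://scott:tiger@dbhost:1433/testdb")
    opts = u.translate_connect_args(username="user")
    opts.update(u.query)
    port = opts.pop("port", None)
    if port and "host" in opts:
        opts["host"] = "%s:%s" % (opts["host"], port)
    # opts -> {'user': 'scott', 'password': 'tiger',
    #          'host': 'dbhost:1433', 'database': 'testdb'}
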
diff --git a/lib/sqlalchemy/dialects/mssql/pyodbc.py b/lib/sqlalchemy/dialects/mssql/pyodbc.py
index 34f81d6e8..db5573c2c 100644
--- a/lib/sqlalchemy/dialects/mssql/pyodbc.py
+++ b/lib/sqlalchemy/dialects/mssql/pyodbc.py
@@ -132,15 +132,13 @@ class _ms_numeric_pyodbc(object):
def bind_processor(self, dialect):
- super_process = super(_ms_numeric_pyodbc, self).\
- bind_processor(dialect)
+ super_process = super(_ms_numeric_pyodbc, self).bind_processor(dialect)
if not dialect._need_decimal_fix:
return super_process
def process(value):
- if self.asdecimal and \
- isinstance(value, decimal.Decimal):
+ if self.asdecimal and isinstance(value, decimal.Decimal):
adjusted = value.adjusted()
if adjusted < 0:
return self._small_dec_to_string(value)
@@ -151,6 +149,7 @@ class _ms_numeric_pyodbc(object):
return super_process(value)
else:
return value
+
return process
# these routines needed for older versions of pyodbc.
@@ -158,30 +157,31 @@ class _ms_numeric_pyodbc(object):
def _small_dec_to_string(self, value):
return "%s0.%s%s" % (
- (value < 0 and '-' or ''),
- '0' * (abs(value.adjusted()) - 1),
- "".join([str(nint) for nint in value.as_tuple()[1]]))
+ (value < 0 and "-" or ""),
+ "0" * (abs(value.adjusted()) - 1),
+ "".join([str(nint) for nint in value.as_tuple()[1]]),
+ )
def _large_dec_to_string(self, value):
_int = value.as_tuple()[1]
- if 'E' in str(value):
+ if "E" in str(value):
result = "%s%s%s" % (
- (value < 0 and '-' or ''),
+ (value < 0 and "-" or ""),
"".join([str(s) for s in _int]),
- "0" * (value.adjusted() - (len(_int) - 1)))
+ "0" * (value.adjusted() - (len(_int) - 1)),
+ )
else:
if (len(_int) - 1) > value.adjusted():
result = "%s%s.%s" % (
- (value < 0 and '-' or ''),
- "".join(
- [str(s) for s in _int][0:value.adjusted() + 1]),
- "".join(
- [str(s) for s in _int][value.adjusted() + 1:]))
+ (value < 0 and "-" or ""),
+ "".join([str(s) for s in _int][0 : value.adjusted() + 1]),
+ "".join([str(s) for s in _int][value.adjusted() + 1 :]),
+ )
else:
result = "%s%s" % (
- (value < 0 and '-' or ''),
- "".join(
- [str(s) for s in _int][0:value.adjusted() + 1]))
+ (value < 0 and "-" or ""),
+ "".join([str(s) for s in _int][0 : value.adjusted() + 1]),
+ )
return result
@@ -212,6 +212,7 @@ class _ms_binary_pyodbc(object):
else:
# pyodbc-specific
return dialect.dbapi.BinaryNull
+
return process
@@ -243,9 +244,11 @@ class MSExecutionContext_pyodbc(MSExecutionContext):
# don't embed the scope_identity select into an
# "INSERT .. DEFAULT VALUES"
- if self._select_lastrowid and \
- self.dialect.use_scope_identity and \
- len(self.parameters[0]):
+ if (
+ self._select_lastrowid
+ and self.dialect.use_scope_identity
+ and len(self.parameters[0])
+ ):
self._embedded_scope_identity = True
self.statement += "; select scope_identity()"
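
With ``use_scope_identity`` in effect and at least one bind parameter present, the statement itself is mutated so the identity fetch travels in the same batch; a sketch of the effect (table and column names are hypothetical):

    # before pre_exec():
    #     INSERT INTO t (x) VALUES (?)
    # after pre_exec() sets _embedded_scope_identity:
    #     INSERT INTO t (x) VALUES (?); select scope_identity()
    # post_exec() then advances over result sets via cursor.nextset()
    # until the scope_identity() row is reachable, storing it as
    # the lastrowid.
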
@@ -281,26 +284,31 @@ class MSDialect_pyodbc(PyODBCConnector, MSDialect):
sqltypes.Numeric: _MSNumeric_pyodbc,
sqltypes.Float: _MSFloat_pyodbc,
BINARY: _BINARY_pyodbc,
-
# SQL Server dialect has a VARBINARY that is just to support
# "deprecate_large_types" w/ VARBINARY(max), but also we must
# handle the usual SQL standard VARBINARY
VARBINARY: _VARBINARY_pyodbc,
sqltypes.VARBINARY: _VARBINARY_pyodbc,
sqltypes.LargeBinary: _VARBINARY_pyodbc,
- }
+ },
)
- def __init__(self, description_encoding=None, fast_executemany=False,
- **params):
- if 'description_encoding' in params:
- self.description_encoding = params.pop('description_encoding')
+ def __init__(
+ self, description_encoding=None, fast_executemany=False, **params
+ ):
+ if "description_encoding" in params:
+ self.description_encoding = params.pop("description_encoding")
super(MSDialect_pyodbc, self).__init__(**params)
- self.use_scope_identity = self.use_scope_identity and \
- self.dbapi and \
- hasattr(self.dbapi.Cursor, 'nextset')
- self._need_decimal_fix = self.dbapi and \
- self._dbapi_version() < (2, 1, 8)
+ self.use_scope_identity = (
+ self.use_scope_identity
+ and self.dbapi
+ and hasattr(self.dbapi.Cursor, "nextset")
+ )
+ self._need_decimal_fix = self.dbapi and self._dbapi_version() < (
+ 2,
+ 1,
+ 8,
+ )
self.fast_executemany = fast_executemany
def _get_server_version_info(self, connection):
@@ -308,16 +316,18 @@ class MSDialect_pyodbc(PyODBCConnector, MSDialect):
# "Version of the instance of SQL Server, in the form
# of 'major.minor.build.revision'"
raw = connection.scalar(
- "SELECT CAST(SERVERPROPERTY('ProductVersion') AS VARCHAR)")
+ "SELECT CAST(SERVERPROPERTY('ProductVersion') AS VARCHAR)"
+ )
except exc.DBAPIError:
# SQL Server docs indicate this function isn't present prior to
# 2008. Before we had the VARCHAR cast above, pyodbc would also
# fail on this query.
- return super(MSDialect_pyodbc, self).\
- _get_server_version_info(connection, allow_chars=False)
+ return super(MSDialect_pyodbc, self)._get_server_version_info(
+ connection, allow_chars=False
+ )
else:
version = []
- r = re.compile(r'[.\-]')
+ r = re.compile(r"[.\-]")
for n in r.split(raw):
try:
version.append(int(n))
@@ -329,17 +339,27 @@ class MSDialect_pyodbc(PyODBCConnector, MSDialect):
if self.fast_executemany:
cursor.fast_executemany = True
super(MSDialect_pyodbc, self).do_executemany(
- cursor, statement, parameters, context=context)
+ cursor, statement, parameters, context=context
+ )
def is_disconnect(self, e, connection, cursor):
if isinstance(e, self.dbapi.Error):
for code in (
- '08S01', '01002', '08003', '08007',
- '08S02', '08001', 'HYT00', 'HY010',
- '10054'):
+ "08S01",
+ "01002",
+ "08003",
+ "08007",
+ "08S02",
+ "08001",
+ "HYT00",
+ "HY010",
+ "10054",
+ ):
if code in str(e):
return True
return super(MSDialect_pyodbc, self).is_disconnect(
- e, connection, cursor)
+ e, connection, cursor
+ )
+
dialect = MSDialect_pyodbc
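
The two helpers above exist because older pyodbc versions mishandle ``Decimal`` values carrying E notation; they rewrite such values as plain digit strings before binding. What the two branches produce, assuming the fix is active:

    import decimal

    # adjusted() < 0 takes the "small" path:
    #   Decimal("1E-6")  ->  "0.000001"
    # 'E' in str(value) takes the first "large" branch:
    #   Decimal("1E+10") ->  "10000000000"
    assert decimal.Decimal("1E-6").adjusted() == -6
    assert decimal.Decimal("1E+10").adjusted() == 10
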
diff --git a/lib/sqlalchemy/dialects/mssql/zxjdbc.py b/lib/sqlalchemy/dialects/mssql/zxjdbc.py
index 3fb93b28a..13fc46e19 100644
--- a/lib/sqlalchemy/dialects/mssql/zxjdbc.py
+++ b/lib/sqlalchemy/dialects/mssql/zxjdbc.py
@@ -44,26 +44,28 @@ class MSExecutionContext_zxjdbc(MSExecutionContext):
self.cursor.nextset()
self._lastrowid = int(row[0])
- if (self.isinsert or self.isupdate or self.isdelete) and \
- self.compiled.returning:
+ if (
+ self.isinsert or self.isupdate or self.isdelete
+ ) and self.compiled.returning:
self._result_proxy = engine.FullyBufferedResultProxy(self)
if self._enable_identity_insert:
table = self.dialect.identifier_preparer.format_table(
- self.compiled.statement.table)
+ self.compiled.statement.table
+ )
self.cursor.execute("SET IDENTITY_INSERT %s OFF" % table)
class MSDialect_zxjdbc(ZxJDBCConnector, MSDialect):
- jdbc_db_name = 'jtds:sqlserver'
- jdbc_driver_name = 'net.sourceforge.jtds.jdbc.Driver'
+ jdbc_db_name = "jtds:sqlserver"
+ jdbc_driver_name = "net.sourceforge.jtds.jdbc.Driver"
execution_ctx_cls = MSExecutionContext_zxjdbc
def _get_server_version_info(self, connection):
return tuple(
- int(x)
- for x in connection.connection.dbversion.split('.')
+ int(x) for x in connection.connection.dbversion.split(".")
)
+
dialect = MSDialect_zxjdbc
diff --git a/lib/sqlalchemy/dialects/mysql/__init__.py b/lib/sqlalchemy/dialects/mysql/__init__.py
index de4e1fa41..ffeb8f486 100644
--- a/lib/sqlalchemy/dialects/mysql/__init__.py
+++ b/lib/sqlalchemy/dialects/mysql/__init__.py
@@ -5,18 +5,56 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-from . import base, mysqldb, oursql, \
- pyodbc, zxjdbc, mysqlconnector, pymysql, \
- gaerdbms, cymysql
+from . import (
+ base,
+ mysqldb,
+ oursql,
+ pyodbc,
+ zxjdbc,
+ mysqlconnector,
+ pymysql,
+ gaerdbms,
+ cymysql,
+)
-from .base import \
- BIGINT, BINARY, BIT, BLOB, BOOLEAN, CHAR, DATE, DATETIME, \
- DECIMAL, DOUBLE, ENUM, DECIMAL,\
- FLOAT, INTEGER, INTEGER, JSON, LONGBLOB, LONGTEXT, MEDIUMBLOB, \
- MEDIUMINT, MEDIUMTEXT, NCHAR, \
- NVARCHAR, NUMERIC, SET, SMALLINT, REAL, TEXT, TIME, TIMESTAMP, \
- TINYBLOB, TINYINT, TINYTEXT,\
- VARBINARY, VARCHAR, YEAR
+from .base import (
+ BIGINT,
+ BINARY,
+ BIT,
+ BLOB,
+ BOOLEAN,
+ CHAR,
+ DATE,
+ DATETIME,
+    DECIMAL,
+    DOUBLE,
+    ENUM,
+    FLOAT,
+    INTEGER,
+ JSON,
+ LONGBLOB,
+ LONGTEXT,
+ MEDIUMBLOB,
+ MEDIUMINT,
+ MEDIUMTEXT,
+ NCHAR,
+ NVARCHAR,
+ NUMERIC,
+ SET,
+ SMALLINT,
+ REAL,
+ TEXT,
+ TIME,
+ TIMESTAMP,
+ TINYBLOB,
+ TINYINT,
+ TINYTEXT,
+ VARBINARY,
+ VARCHAR,
+ YEAR,
+)
from .dml import insert, Insert
@@ -25,10 +63,41 @@ base.dialect = dialect = mysqldb.dialect
__all__ = (
- 'BIGINT', 'BINARY', 'BIT', 'BLOB', 'BOOLEAN', 'CHAR', 'DATE', 'DATETIME',
- 'DECIMAL', 'DOUBLE', 'ENUM', 'DECIMAL', 'FLOAT', 'INTEGER', 'INTEGER',
- 'JSON', 'LONGBLOB', 'LONGTEXT', 'MEDIUMBLOB', 'MEDIUMINT', 'MEDIUMTEXT',
- 'NCHAR', 'NVARCHAR', 'NUMERIC', 'SET', 'SMALLINT', 'REAL', 'TEXT', 'TIME',
- 'TIMESTAMP', 'TINYBLOB', 'TINYINT', 'TINYTEXT', 'VARBINARY', 'VARCHAR',
- 'YEAR', 'dialect'
+ "BIGINT",
+ "BINARY",
+ "BIT",
+ "BLOB",
+ "BOOLEAN",
+ "CHAR",
+ "DATE",
+ "DATETIME",
+ "DECIMAL",
+ "DOUBLE",
+ "ENUM",
+ "DECIMAL",
+ "FLOAT",
+ "INTEGER",
+ "INTEGER",
+ "JSON",
+ "LONGBLOB",
+ "LONGTEXT",
+ "MEDIUMBLOB",
+ "MEDIUMINT",
+ "MEDIUMTEXT",
+ "NCHAR",
+ "NVARCHAR",
+ "NUMERIC",
+ "SET",
+ "SMALLINT",
+ "REAL",
+ "TEXT",
+ "TIME",
+ "TIMESTAMP",
+ "TINYBLOB",
+ "TINYINT",
+ "TINYTEXT",
+ "VARBINARY",
+ "VARCHAR",
+ "YEAR",
+ "dialect",
)
diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py
index 673d4b9ff..7b0d0618c 100644
--- a/lib/sqlalchemy/dialects/mysql/base.py
+++ b/lib/sqlalchemy/dialects/mysql/base.py
@@ -746,85 +746,340 @@ from ...engine import reflection
from ...engine import default
from ... import types as sqltypes
from ...util import topological
-from ...types import DATE, BOOLEAN, \
- BLOB, BINARY, VARBINARY
+from ...types import DATE, BOOLEAN, BLOB, BINARY, VARBINARY
from . import reflection as _reflection
-from .types import BIGINT, BIT, CHAR, DECIMAL, DATETIME, \
- DOUBLE, FLOAT, INTEGER, LONGBLOB, LONGTEXT, MEDIUMBLOB, MEDIUMINT, \
- MEDIUMTEXT, NCHAR, NUMERIC, NVARCHAR, REAL, SMALLINT, TEXT, TIME, \
- TIMESTAMP, TINYBLOB, TINYINT, TINYTEXT, VARCHAR, YEAR
-from .types import _StringType, _IntegerType, _NumericType, \
- _FloatType, _MatchType
+from .types import (
+ BIGINT,
+ BIT,
+ CHAR,
+ DECIMAL,
+ DATETIME,
+ DOUBLE,
+ FLOAT,
+ INTEGER,
+ LONGBLOB,
+ LONGTEXT,
+ MEDIUMBLOB,
+ MEDIUMINT,
+ MEDIUMTEXT,
+ NCHAR,
+ NUMERIC,
+ NVARCHAR,
+ REAL,
+ SMALLINT,
+ TEXT,
+ TIME,
+ TIMESTAMP,
+ TINYBLOB,
+ TINYINT,
+ TINYTEXT,
+ VARCHAR,
+ YEAR,
+)
+from .types import (
+ _StringType,
+ _IntegerType,
+ _NumericType,
+ _FloatType,
+ _MatchType,
+)
from .enumerated import ENUM, SET
from .json import JSON, JSONIndexType, JSONPathType
RESERVED_WORDS = set(
- ['accessible', 'add', 'all', 'alter', 'analyze', 'and', 'as', 'asc',
- 'asensitive', 'before', 'between', 'bigint', 'binary', 'blob', 'both',
- 'by', 'call', 'cascade', 'case', 'change', 'char', 'character', 'check',
- 'collate', 'column', 'condition', 'constraint', 'continue', 'convert',
- 'create', 'cross', 'current_date', 'current_time', 'current_timestamp',
- 'current_user', 'cursor', 'database', 'databases', 'day_hour',
- 'day_microsecond', 'day_minute', 'day_second', 'dec', 'decimal',
- 'declare', 'default', 'delayed', 'delete', 'desc', 'describe',
- 'deterministic', 'distinct', 'distinctrow', 'div', 'double', 'drop',
- 'dual', 'each', 'else', 'elseif', 'enclosed', 'escaped', 'exists',
- 'exit', 'explain', 'false', 'fetch', 'float', 'float4', 'float8',
- 'for', 'force', 'foreign', 'from', 'fulltext', 'grant', 'group',
- 'having', 'high_priority', 'hour_microsecond', 'hour_minute',
- 'hour_second', 'if', 'ignore', 'in', 'index', 'infile', 'inner', 'inout',
- 'insensitive', 'insert', 'int', 'int1', 'int2', 'int3', 'int4', 'int8',
- 'integer', 'interval', 'into', 'is', 'iterate', 'join', 'key', 'keys',
- 'kill', 'leading', 'leave', 'left', 'like', 'limit', 'linear', 'lines',
- 'load', 'localtime', 'localtimestamp', 'lock', 'long', 'longblob',
- 'longtext', 'loop', 'low_priority', 'master_ssl_verify_server_cert',
- 'match', 'mediumblob', 'mediumint', 'mediumtext', 'middleint',
- 'minute_microsecond', 'minute_second', 'mod', 'modifies', 'natural',
- 'not', 'no_write_to_binlog', 'null', 'numeric', 'on', 'optimize',
- 'option', 'optionally', 'or', 'order', 'out', 'outer', 'outfile',
- 'precision', 'primary', 'procedure', 'purge', 'range', 'read', 'reads',
- 'read_only', 'read_write', 'real', 'references', 'regexp', 'release',
- 'rename', 'repeat', 'replace', 'require', 'restrict', 'return',
- 'revoke', 'right', 'rlike', 'schema', 'schemas', 'second_microsecond',
- 'select', 'sensitive', 'separator', 'set', 'show', 'smallint', 'spatial',
- 'specific', 'sql', 'sqlexception', 'sqlstate', 'sqlwarning',
- 'sql_big_result', 'sql_calc_found_rows', 'sql_small_result', 'ssl',
- 'starting', 'straight_join', 'table', 'terminated', 'then', 'tinyblob',
- 'tinyint', 'tinytext', 'to', 'trailing', 'trigger', 'true', 'undo',
- 'union', 'unique', 'unlock', 'unsigned', 'update', 'usage', 'use',
- 'using', 'utc_date', 'utc_time', 'utc_timestamp', 'values', 'varbinary',
- 'varchar', 'varcharacter', 'varying', 'when', 'where', 'while', 'with',
-
- 'write', 'x509', 'xor', 'year_month', 'zerofill', # 5.0
-
- 'columns', 'fields', 'privileges', 'soname', 'tables', # 4.1
-
- 'accessible', 'linear', 'master_ssl_verify_server_cert', 'range',
- 'read_only', 'read_write', # 5.1
-
- 'general', 'ignore_server_ids', 'master_heartbeat_period', 'maxvalue',
- 'resignal', 'signal', 'slow', # 5.5
-
- 'get', 'io_after_gtids', 'io_before_gtids', 'master_bind', 'one_shot',
- 'partition', 'sql_after_gtids', 'sql_before_gtids', # 5.6
-
- 'generated', 'optimizer_costs', 'stored', 'virtual', # 5.7
-
- 'admin', 'cume_dist', 'empty', 'except', 'first_value', 'grouping',
- 'function', 'groups', 'json_table', 'last_value', 'nth_value',
- 'ntile', 'of', 'over', 'percent_rank', 'persist', 'persist_only',
- 'rank', 'recursive', 'role', 'row', 'rows', 'row_number', 'system',
- 'window', # 8.0
- ])
+ [
+ "accessible",
+ "add",
+ "all",
+ "alter",
+ "analyze",
+ "and",
+ "as",
+ "asc",
+ "asensitive",
+ "before",
+ "between",
+ "bigint",
+ "binary",
+ "blob",
+ "both",
+ "by",
+ "call",
+ "cascade",
+ "case",
+ "change",
+ "char",
+ "character",
+ "check",
+ "collate",
+ "column",
+ "condition",
+ "constraint",
+ "continue",
+ "convert",
+ "create",
+ "cross",
+ "current_date",
+ "current_time",
+ "current_timestamp",
+ "current_user",
+ "cursor",
+ "database",
+ "databases",
+ "day_hour",
+ "day_microsecond",
+ "day_minute",
+ "day_second",
+ "dec",
+ "decimal",
+ "declare",
+ "default",
+ "delayed",
+ "delete",
+ "desc",
+ "describe",
+ "deterministic",
+ "distinct",
+ "distinctrow",
+ "div",
+ "double",
+ "drop",
+ "dual",
+ "each",
+ "else",
+ "elseif",
+ "enclosed",
+ "escaped",
+ "exists",
+ "exit",
+ "explain",
+ "false",
+ "fetch",
+ "float",
+ "float4",
+ "float8",
+ "for",
+ "force",
+ "foreign",
+ "from",
+ "fulltext",
+ "grant",
+ "group",
+ "having",
+ "high_priority",
+ "hour_microsecond",
+ "hour_minute",
+ "hour_second",
+ "if",
+ "ignore",
+ "in",
+ "index",
+ "infile",
+ "inner",
+ "inout",
+ "insensitive",
+ "insert",
+ "int",
+ "int1",
+ "int2",
+ "int3",
+ "int4",
+ "int8",
+ "integer",
+ "interval",
+ "into",
+ "is",
+ "iterate",
+ "join",
+ "key",
+ "keys",
+ "kill",
+ "leading",
+ "leave",
+ "left",
+ "like",
+ "limit",
+ "linear",
+ "lines",
+ "load",
+ "localtime",
+ "localtimestamp",
+ "lock",
+ "long",
+ "longblob",
+ "longtext",
+ "loop",
+ "low_priority",
+ "master_ssl_verify_server_cert",
+ "match",
+ "mediumblob",
+ "mediumint",
+ "mediumtext",
+ "middleint",
+ "minute_microsecond",
+ "minute_second",
+ "mod",
+ "modifies",
+ "natural",
+ "not",
+ "no_write_to_binlog",
+ "null",
+ "numeric",
+ "on",
+ "optimize",
+ "option",
+ "optionally",
+ "or",
+ "order",
+ "out",
+ "outer",
+ "outfile",
+ "precision",
+ "primary",
+ "procedure",
+ "purge",
+ "range",
+ "read",
+ "reads",
+ "read_only",
+ "read_write",
+ "real",
+ "references",
+ "regexp",
+ "release",
+ "rename",
+ "repeat",
+ "replace",
+ "require",
+ "restrict",
+ "return",
+ "revoke",
+ "right",
+ "rlike",
+ "schema",
+ "schemas",
+ "second_microsecond",
+ "select",
+ "sensitive",
+ "separator",
+ "set",
+ "show",
+ "smallint",
+ "spatial",
+ "specific",
+ "sql",
+ "sqlexception",
+ "sqlstate",
+ "sqlwarning",
+ "sql_big_result",
+ "sql_calc_found_rows",
+ "sql_small_result",
+ "ssl",
+ "starting",
+ "straight_join",
+ "table",
+ "terminated",
+ "then",
+ "tinyblob",
+ "tinyint",
+ "tinytext",
+ "to",
+ "trailing",
+ "trigger",
+ "true",
+ "undo",
+ "union",
+ "unique",
+ "unlock",
+ "unsigned",
+ "update",
+ "usage",
+ "use",
+ "using",
+ "utc_date",
+ "utc_time",
+ "utc_timestamp",
+ "values",
+ "varbinary",
+ "varchar",
+ "varcharacter",
+ "varying",
+ "when",
+ "where",
+ "while",
+ "with",
+ "write",
+ "x509",
+ "xor",
+ "year_month",
+ "zerofill", # 5.0
+ "columns",
+ "fields",
+ "privileges",
+ "soname",
+ "tables", # 4.1
+ "accessible",
+ "linear",
+ "master_ssl_verify_server_cert",
+ "range",
+ "read_only",
+ "read_write", # 5.1
+ "general",
+ "ignore_server_ids",
+ "master_heartbeat_period",
+ "maxvalue",
+ "resignal",
+ "signal",
+ "slow", # 5.5
+ "get",
+ "io_after_gtids",
+ "io_before_gtids",
+ "master_bind",
+ "one_shot",
+ "partition",
+ "sql_after_gtids",
+ "sql_before_gtids", # 5.6
+ "generated",
+ "optimizer_costs",
+ "stored",
+ "virtual", # 5.7
+ "admin",
+ "cume_dist",
+ "empty",
+ "except",
+ "first_value",
+ "grouping",
+ "function",
+ "groups",
+ "json_table",
+ "last_value",
+ "nth_value",
+ "ntile",
+ "of",
+ "over",
+ "percent_rank",
+ "persist",
+ "persist_only",
+ "rank",
+ "recursive",
+ "role",
+ "row",
+ "rows",
+ "row_number",
+ "system",
+ "window", # 8.0
+ ]
+)
AUTOCOMMIT_RE = re.compile(
- r'\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER|LOAD +DATA|REPLACE)',
- re.I | re.UNICODE)
+ r"\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER|LOAD +DATA|REPLACE)",
+ re.I | re.UNICODE,
+)
SET_RE = re.compile(
- r'\s*SET\s+(?:(?:GLOBAL|SESSION)\s+)?\w',
- re.I | re.UNICODE)
+ r"\s*SET\s+(?:(?:GLOBAL|SESSION)\s+)?\w", re.I | re.UNICODE
+)
# old names
@@ -870,52 +1125,50 @@ colspecs = {
sqltypes.MatchType: _MatchType,
sqltypes.JSON: JSON,
sqltypes.JSON.JSONIndexType: JSONIndexType,
- sqltypes.JSON.JSONPathType: JSONPathType
-
+ sqltypes.JSON.JSONPathType: JSONPathType,
}
# Everything 3.23 through 5.1 excepting OpenGIS types.
ischema_names = {
- 'bigint': BIGINT,
- 'binary': BINARY,
- 'bit': BIT,
- 'blob': BLOB,
- 'boolean': BOOLEAN,
- 'char': CHAR,
- 'date': DATE,
- 'datetime': DATETIME,
- 'decimal': DECIMAL,
- 'double': DOUBLE,
- 'enum': ENUM,
- 'fixed': DECIMAL,
- 'float': FLOAT,
- 'int': INTEGER,
- 'integer': INTEGER,
- 'json': JSON,
- 'longblob': LONGBLOB,
- 'longtext': LONGTEXT,
- 'mediumblob': MEDIUMBLOB,
- 'mediumint': MEDIUMINT,
- 'mediumtext': MEDIUMTEXT,
- 'nchar': NCHAR,
- 'nvarchar': NVARCHAR,
- 'numeric': NUMERIC,
- 'set': SET,
- 'smallint': SMALLINT,
- 'text': TEXT,
- 'time': TIME,
- 'timestamp': TIMESTAMP,
- 'tinyblob': TINYBLOB,
- 'tinyint': TINYINT,
- 'tinytext': TINYTEXT,
- 'varbinary': VARBINARY,
- 'varchar': VARCHAR,
- 'year': YEAR,
+ "bigint": BIGINT,
+ "binary": BINARY,
+ "bit": BIT,
+ "blob": BLOB,
+ "boolean": BOOLEAN,
+ "char": CHAR,
+ "date": DATE,
+ "datetime": DATETIME,
+ "decimal": DECIMAL,
+ "double": DOUBLE,
+ "enum": ENUM,
+ "fixed": DECIMAL,
+ "float": FLOAT,
+ "int": INTEGER,
+ "integer": INTEGER,
+ "json": JSON,
+ "longblob": LONGBLOB,
+ "longtext": LONGTEXT,
+ "mediumblob": MEDIUMBLOB,
+ "mediumint": MEDIUMINT,
+ "mediumtext": MEDIUMTEXT,
+ "nchar": NCHAR,
+ "nvarchar": NVARCHAR,
+ "numeric": NUMERIC,
+ "set": SET,
+ "smallint": SMALLINT,
+ "text": TEXT,
+ "time": TIME,
+ "timestamp": TIMESTAMP,
+ "tinyblob": TINYBLOB,
+ "tinyint": TINYINT,
+ "tinytext": TINYTEXT,
+ "varbinary": VARBINARY,
+ "varchar": VARCHAR,
+ "year": YEAR,
}
class MySQLExecutionContext(default.DefaultExecutionContext):
-
def should_autocommit_text(self, statement):
return AUTOCOMMIT_RE.match(statement)
@@ -932,7 +1185,7 @@ class MySQLCompiler(compiler.SQLCompiler):
"""Overridden from base SQLCompiler value"""
extract_map = compiler.SQLCompiler.extract_map.copy()
- extract_map.update({'milliseconds': 'millisecond'})
+ extract_map.update({"milliseconds": "millisecond"})
def visit_random_func(self, fn, **kw):
return "rand%s" % self.function_argspec(fn)
@@ -943,12 +1196,14 @@ class MySQLCompiler(compiler.SQLCompiler):
def visit_json_getitem_op_binary(self, binary, operator, **kw):
return "JSON_EXTRACT(%s, %s)" % (
self.process(binary.left, **kw),
- self.process(binary.right, **kw))
+ self.process(binary.right, **kw),
+ )
def visit_json_path_getitem_op_binary(self, binary, operator, **kw):
return "JSON_EXTRACT(%s, %s)" % (
self.process(binary.left, **kw),
- self.process(binary.right, **kw))
+ self.process(binary.right, **kw),
+ )
def visit_on_duplicate_key_update(self, on_duplicate, **kw):
if on_duplicate._parameter_ordering:
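
Both JSON visitors above render ``JSON_EXTRACT``; the difference is only in how the right-hand bind is formatted, a single quoted key versus a full JSON path. A short sketch against a hypothetical table:

    from sqlalchemy import Column, MetaData, Table
    from sqlalchemy.dialects.mysql import JSON

    t = Table("t", MetaData(), Column("data", JSON))
    t.c.data["key"]          # JSON_EXTRACT(data, %s) with '$."key"'
    t.c.data[("a", "b")]     # JSON_EXTRACT(data, %s) with '$."a"."b"'
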
@@ -958,7 +1213,8 @@ class MySQLCompiler(compiler.SQLCompiler):
]
ordered_keys = set(parameter_ordering)
cols = [
- self.statement.table.c[key] for key in parameter_ordering
+ self.statement.table.c[key]
+ for key in parameter_ordering
if key in self.statement.table.c
] + [
c for c in self.statement.table.c if c.key not in ordered_keys
@@ -979,9 +1235,11 @@ class MySQLCompiler(compiler.SQLCompiler):
val = val._clone()
val.type = column.type
value_text = self.process(val.self_group(), use_schema=False)
- elif isinstance(val, elements.ColumnClause) \
- and val.table is on_duplicate.inserted_alias:
- value_text = 'VALUES(' + self.preparer.quote(column.name) + ')'
+ elif (
+ isinstance(val, elements.ColumnClause)
+ and val.table is on_duplicate.inserted_alias
+ ):
+ value_text = "VALUES(" + self.preparer.quote(column.name) + ")"
else:
value_text = self.process(val.self_group(), use_schema=False)
name_text = self.preparer.quote(column.name)
@@ -990,22 +1248,27 @@ class MySQLCompiler(compiler.SQLCompiler):
non_matching = set(on_duplicate.update) - set(c.key for c in cols)
if non_matching:
util.warn(
- 'Additional column names not matching '
- "any column keys in table '%s': %s" % (
+ "Additional column names not matching "
+ "any column keys in table '%s': %s"
+ % (
self.statement.table.name,
- (', '.join("'%s'" % c for c in non_matching))
+ (", ".join("'%s'" % c for c in non_matching)),
)
)
- return 'ON DUPLICATE KEY UPDATE ' + ', '.join(clauses)
+ return "ON DUPLICATE KEY UPDATE " + ", ".join(clauses)
def visit_concat_op_binary(self, binary, operator, **kw):
- return "concat(%s, %s)" % (self.process(binary.left, **kw),
- self.process(binary.right, **kw))
+ return "concat(%s, %s)" % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw),
+ )
def visit_match_op_binary(self, binary, operator, **kw):
- return "MATCH (%s) AGAINST (%s IN BOOLEAN MODE)" % \
- (self.process(binary.left, **kw), self.process(binary.right, **kw))
+ return "MATCH (%s) AGAINST (%s IN BOOLEAN MODE)" % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw),
+ )
def get_from_hint_text(self, table, text):
return text
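
A usage sketch for the ``ON DUPLICATE KEY UPDATE`` compilation above; referencing ``stmt.inserted`` is what triggers the ``VALUES(col)`` branch (table and columns are hypothetical):

    from sqlalchemy import Column, Integer, MetaData, String, Table
    from sqlalchemy.dialects.mysql import insert

    t = Table(
        "t", MetaData(),
        Column("id", Integer, primary_key=True),
        Column("data", String(50)),
    )
    stmt = insert(t).values(id=1, data="x")
    stmt = stmt.on_duplicate_key_update(data=stmt.inserted.data)
    # INSERT INTO t (id, data) VALUES (%s, %s)
    # ON DUPLICATE KEY UPDATE data = VALUES(data)
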
@@ -1016,26 +1279,35 @@ class MySQLCompiler(compiler.SQLCompiler):
if isinstance(type_, sqltypes.TypeDecorator):
return self.visit_typeclause(typeclause, type_.impl, **kw)
elif isinstance(type_, sqltypes.Integer):
- if getattr(type_, 'unsigned', False):
- return 'UNSIGNED INTEGER'
+ if getattr(type_, "unsigned", False):
+ return "UNSIGNED INTEGER"
else:
- return 'SIGNED INTEGER'
+ return "SIGNED INTEGER"
elif isinstance(type_, sqltypes.TIMESTAMP):
- return 'DATETIME'
- elif isinstance(type_, (sqltypes.DECIMAL, sqltypes.DateTime,
- sqltypes.Date, sqltypes.Time)):
+ return "DATETIME"
+ elif isinstance(
+ type_,
+ (
+ sqltypes.DECIMAL,
+ sqltypes.DateTime,
+ sqltypes.Date,
+ sqltypes.Time,
+ ),
+ ):
return self.dialect.type_compiler.process(type_)
- elif isinstance(type_, sqltypes.String) \
- and not isinstance(type_, (ENUM, SET)):
+ elif isinstance(type_, sqltypes.String) and not isinstance(
+ type_, (ENUM, SET)
+ ):
adapted = CHAR._adapt_string_for_cast(type_)
return self.dialect.type_compiler.process(adapted)
elif isinstance(type_, sqltypes._Binary):
- return 'BINARY'
+ return "BINARY"
elif isinstance(type_, sqltypes.JSON):
return "JSON"
elif isinstance(type_, sqltypes.NUMERIC):
- return self.dialect.type_compiler.process(
- type_).replace('NUMERIC', 'DECIMAL')
+ return self.dialect.type_compiler.process(type_).replace(
+ "NUMERIC", "DECIMAL"
+ )
else:
return None
@@ -1044,23 +1316,25 @@ class MySQLCompiler(compiler.SQLCompiler):
if not self.dialect._supports_cast:
util.warn(
"Current MySQL version does not support "
- "CAST; the CAST will be skipped.")
+ "CAST; the CAST will be skipped."
+ )
return self.process(cast.clause.self_group(), **kw)
type_ = self.process(cast.typeclause)
if type_ is None:
util.warn(
"Datatype %s does not support CAST on MySQL; "
- "the CAST will be skipped." %
- self.dialect.type_compiler.process(cast.typeclause.type))
+ "the CAST will be skipped."
+ % self.dialect.type_compiler.process(cast.typeclause.type)
+ )
return self.process(cast.clause.self_group(), **kw)
- return 'CAST(%s AS %s)' % (self.process(cast.clause, **kw), type_)
+ return "CAST(%s AS %s)" % (self.process(cast.clause, **kw), type_)
def render_literal_value(self, value, type_):
value = super(MySQLCompiler, self).render_literal_value(value, type_)
if self.dialect._backslash_escapes:
- value = value.replace('\\', '\\\\')
+ value = value.replace("\\", "\\\\")
return value
# override native_boolean=False behavior here, as
@@ -1096,12 +1370,15 @@ class MySQLCompiler(compiler.SQLCompiler):
else:
join_type = " INNER JOIN "
- return ''.join(
- (self.process(join.left, asfrom=True, **kwargs),
- join_type,
- self.process(join.right, asfrom=True, **kwargs),
- " ON ",
- self.process(join.onclause, **kwargs)))
+ return "".join(
+ (
+ self.process(join.left, asfrom=True, **kwargs),
+ join_type,
+ self.process(join.right, asfrom=True, **kwargs),
+ " ON ",
+ self.process(join.onclause, **kwargs),
+ )
+ )
def for_update_clause(self, select, **kw):
if select._for_update_arg.read:
@@ -1118,11 +1395,13 @@ class MySQLCompiler(compiler.SQLCompiler):
# The latter is more readable for offsets but we're stuck with the
# former until we can refine dialects by server revision.
- limit_clause, offset_clause = select._limit_clause, \
- select._offset_clause
+ limit_clause, offset_clause = (
+ select._limit_clause,
+ select._offset_clause,
+ )
if limit_clause is None and offset_clause is None:
- return ''
+ return ""
elif offset_clause is not None:
# As suggested by the MySQL docs, need to apply an
# artificial limit if one wasn't provided
@@ -1134,35 +1413,38 @@ class MySQLCompiler(compiler.SQLCompiler):
# but also is consistent with the usage of the upper
# bound as part of MySQL's "syntax" for OFFSET with
# no LIMIT
- return ' \n LIMIT %s, %s' % (
+ return " \n LIMIT %s, %s" % (
self.process(offset_clause, **kw),
- "18446744073709551615")
+ "18446744073709551615",
+ )
else:
- return ' \n LIMIT %s, %s' % (
+ return " \n LIMIT %s, %s" % (
self.process(offset_clause, **kw),
- self.process(limit_clause, **kw))
+ self.process(limit_clause, **kw),
+ )
else:
# No offset provided, so just use the limit
- return ' \n LIMIT %s' % (self.process(limit_clause, **kw),)
+ return " \n LIMIT %s" % (self.process(limit_clause, **kw),)
def update_limit_clause(self, update_stmt):
- limit = update_stmt.kwargs.get('%s_limit' % self.dialect.name, None)
+ limit = update_stmt.kwargs.get("%s_limit" % self.dialect.name, None)
if limit:
return "LIMIT %s" % limit
else:
return None
- def update_tables_clause(self, update_stmt, from_table,
- extra_froms, **kw):
- return ', '.join(t._compiler_dispatch(self, asfrom=True, **kw)
- for t in [from_table] + list(extra_froms))
+ def update_tables_clause(self, update_stmt, from_table, extra_froms, **kw):
+ return ", ".join(
+ t._compiler_dispatch(self, asfrom=True, **kw)
+ for t in [from_table] + list(extra_froms)
+ )
- def update_from_clause(self, update_stmt, from_table,
- extra_froms, from_hints, **kw):
+ def update_from_clause(
+ self, update_stmt, from_table, extra_froms, from_hints, **kw
+ ):
return None
- def delete_table_clause(self, delete_stmt, from_table,
- extra_froms):
+ def delete_table_clause(self, delete_stmt, from_table, extra_froms):
"""If we have extra froms make sure we render any alias as hint."""
ashint = False
if extra_froms:
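
MySQL accepts OFFSET only as part of LIMIT, so ``limit_clause()`` above substitutes the documented maximum row count (2**64 - 1) as the upper bound when only an offset is present. The three rendered forms, with bind values shown inline for readability:

    #   .limit(10).offset(5)  ->  LIMIT 5, 10
    #   .offset(5)            ->  LIMIT 5, 18446744073709551615
    #   .limit(10)            ->  LIMIT 10
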
@@ -1171,24 +1453,27 @@ class MySQLCompiler(compiler.SQLCompiler):
self, asfrom=True, iscrud=True, ashint=ashint
)
- def delete_extra_from_clause(self, delete_stmt, from_table,
- extra_froms, from_hints, **kw):
+ def delete_extra_from_clause(
+ self, delete_stmt, from_table, extra_froms, from_hints, **kw
+ ):
"""Render the DELETE .. USING clause specific to MySQL."""
- return "USING " + ', '.join(
- t._compiler_dispatch(self, asfrom=True,
- fromhints=from_hints, **kw)
- for t in [from_table] + extra_froms)
+ return "USING " + ", ".join(
+ t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw)
+ for t in [from_table] + extra_froms
+ )
def visit_empty_set_expr(self, element_types):
return (
"SELECT %(outer)s FROM (SELECT %(inner)s) "
- "as _empty_set WHERE 1!=1" % {
+ "as _empty_set WHERE 1!=1"
+ % {
"inner": ", ".join(
"1 AS _in_%s" % idx
- for idx, type_ in enumerate(element_types)),
+ for idx, type_ in enumerate(element_types)
+ ),
"outer": ", ".join(
- "_in_%s" % idx
- for idx, type_ in enumerate(element_types))
+ "_in_%s" % idx for idx, type_ in enumerate(element_types)
+ ),
}
)
@@ -1200,35 +1485,39 @@ class MySQLDDLCompiler(compiler.DDLCompiler):
colspec = [
self.preparer.format_column(column),
self.dialect.type_compiler.process(
- column.type, type_expression=column)
+ column.type, type_expression=column
+ ),
]
is_timestamp = isinstance(column.type, sqltypes.TIMESTAMP)
if not column.nullable:
- colspec.append('NOT NULL')
+ colspec.append("NOT NULL")
# see: http://docs.sqlalchemy.org/en/latest/dialects/
# mysql.html#mysql_timestamp_null
elif column.nullable and is_timestamp:
- colspec.append('NULL')
+ colspec.append("NULL")
default = self.get_column_default_string(column)
if default is not None:
- colspec.append('DEFAULT ' + default)
+ colspec.append("DEFAULT " + default)
comment = column.comment
if comment is not None:
literal = self.sql_compiler.render_literal_value(
- comment, sqltypes.String())
- colspec.append('COMMENT ' + literal)
+ comment, sqltypes.String()
+ )
+ colspec.append("COMMENT " + literal)
- if column.table is not None \
- and column is column.table._autoincrement_column and \
- column.server_default is None:
- colspec.append('AUTO_INCREMENT')
+ if (
+ column.table is not None
+ and column is column.table._autoincrement_column
+ and column.server_default is None
+ ):
+ colspec.append("AUTO_INCREMENT")
- return ' '.join(colspec)
+ return " ".join(colspec)
def post_create_table(self, table):
"""Build table-level CREATE options like ENGINE and COLLATE."""
@@ -1236,76 +1525,94 @@ class MySQLDDLCompiler(compiler.DDLCompiler):
table_opts = []
opts = dict(
- (
- k[len(self.dialect.name) + 1:].upper(),
- v
- )
+ (k[len(self.dialect.name) + 1 :].upper(), v)
for k, v in table.kwargs.items()
- if k.startswith('%s_' % self.dialect.name)
+ if k.startswith("%s_" % self.dialect.name)
)
if table.comment is not None:
- opts['COMMENT'] = table.comment
+ opts["COMMENT"] = table.comment
partition_options = [
- 'PARTITION_BY', 'PARTITIONS', 'SUBPARTITIONS',
- 'SUBPARTITION_BY'
+ "PARTITION_BY",
+ "PARTITIONS",
+ "SUBPARTITIONS",
+ "SUBPARTITION_BY",
]
nonpart_options = set(opts).difference(partition_options)
part_options = set(opts).intersection(partition_options)
- for opt in topological.sort([
- ('DEFAULT_CHARSET', 'COLLATE'),
- ('DEFAULT_CHARACTER_SET', 'COLLATE'),
- ], nonpart_options):
+ for opt in topological.sort(
+ [
+ ("DEFAULT_CHARSET", "COLLATE"),
+ ("DEFAULT_CHARACTER_SET", "COLLATE"),
+ ],
+ nonpart_options,
+ ):
arg = opts[opt]
if opt in _reflection._options_of_type_string:
arg = self.sql_compiler.render_literal_value(
- arg, sqltypes.String())
-
- if opt in ('DATA_DIRECTORY', 'INDEX_DIRECTORY',
- 'DEFAULT_CHARACTER_SET', 'CHARACTER_SET',
- 'DEFAULT_CHARSET',
- 'DEFAULT_COLLATE'):
- opt = opt.replace('_', ' ')
+ arg, sqltypes.String()
+ )
- joiner = '='
- if opt in ('TABLESPACE', 'DEFAULT CHARACTER SET',
- 'CHARACTER SET', 'COLLATE'):
- joiner = ' '
+ if opt in (
+ "DATA_DIRECTORY",
+ "INDEX_DIRECTORY",
+ "DEFAULT_CHARACTER_SET",
+ "CHARACTER_SET",
+ "DEFAULT_CHARSET",
+ "DEFAULT_COLLATE",
+ ):
+ opt = opt.replace("_", " ")
+
+ joiner = "="
+ if opt in (
+ "TABLESPACE",
+ "DEFAULT CHARACTER SET",
+ "CHARACTER SET",
+ "COLLATE",
+ ):
+ joiner = " "
table_opts.append(joiner.join((opt, arg)))
- for opt in topological.sort([
- ('PARTITION_BY', 'PARTITIONS'),
- ('PARTITION_BY', 'SUBPARTITION_BY'),
- ('PARTITION_BY', 'SUBPARTITIONS'),
- ('PARTITIONS', 'SUBPARTITIONS'),
- ('PARTITIONS', 'SUBPARTITION_BY'),
- ('SUBPARTITION_BY', 'SUBPARTITIONS')
- ], part_options):
+ for opt in topological.sort(
+ [
+ ("PARTITION_BY", "PARTITIONS"),
+ ("PARTITION_BY", "SUBPARTITION_BY"),
+ ("PARTITION_BY", "SUBPARTITIONS"),
+ ("PARTITIONS", "SUBPARTITIONS"),
+ ("PARTITIONS", "SUBPARTITION_BY"),
+ ("SUBPARTITION_BY", "SUBPARTITIONS"),
+ ],
+ part_options,
+ ):
arg = opts[opt]
if opt in _reflection._options_of_type_string:
arg = self.sql_compiler.render_literal_value(
- arg, sqltypes.String())
+ arg, sqltypes.String()
+ )
- opt = opt.replace('_', ' ')
- joiner = ' '
+ opt = opt.replace("_", " ")
+ joiner = " "
table_opts.append(joiner.join((opt, arg)))
- return ' '.join(table_opts)
+ return " ".join(table_opts)
def visit_create_index(self, create, **kw):
index = create.element
self._verify_index_table(index)
preparer = self.preparer
table = preparer.format_table(index.table)
- columns = [self.sql_compiler.process(expr, include_table=False,
- literal_binds=True)
- for expr in index.expressions]
+ columns = [
+ self.sql_compiler.process(
+ expr, include_table=False, literal_binds=True
+ )
+ for expr in index.expressions
+ ]
name = self._prepared_index_name(index)
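
``post_create_table()`` above turns ``mysql_``-prefixed Table keyword arguments into trailing CREATE TABLE options, sorting charset options ahead of COLLATE. A sketch with illustrative option values:

    from sqlalchemy import Column, Integer, MetaData, Table

    t = Table(
        "t", MetaData(),
        Column("id", Integer, primary_key=True),
        mysql_engine="InnoDB",
        mysql_default_charset="utf8",
        mysql_collate="utf8_general_ci",
    )
    # CREATE TABLE t (...) renders with trailing options along the lines
    # of: ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE utf8_general_ci
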
@@ -1313,53 +1620,54 @@ class MySQLDDLCompiler(compiler.DDLCompiler):
if index.unique:
text += "UNIQUE "
- index_prefix = index.kwargs.get('mysql_prefix', None)
+ index_prefix = index.kwargs.get("mysql_prefix", None)
if index_prefix:
- text += index_prefix + ' '
+ text += index_prefix + " "
text += "INDEX %s ON %s " % (name, table)
- length = index.dialect_options['mysql']['length']
+ length = index.dialect_options["mysql"]["length"]
if length is not None:
if isinstance(length, dict):
# length value can be a (column_name --> integer value)
# mapping specifying the prefix length for each column of the
# index
- columns = ', '.join(
- '%s(%d)' % (expr, length[col.name]) if col.name in length
- else
- (
- '%s(%d)' % (expr, length[expr]) if expr in length
- else '%s' % expr
+ columns = ", ".join(
+ "%s(%d)" % (expr, length[col.name])
+ if col.name in length
+ else (
+ "%s(%d)" % (expr, length[expr])
+ if expr in length
+ else "%s" % expr
)
for col, expr in zip(index.expressions, columns)
)
else:
# or can be an integer value specifying the same
# prefix length for all columns of the index
- columns = ', '.join(
- '%s(%d)' % (col, length)
- for col in columns
+ columns = ", ".join(
+ "%s(%d)" % (col, length) for col in columns
)
else:
- columns = ', '.join(columns)
- text += '(%s)' % columns
+ columns = ", ".join(columns)
+ text += "(%s)" % columns
- parser = index.dialect_options['mysql']['with_parser']
+ parser = index.dialect_options["mysql"]["with_parser"]
if parser is not None:
- text += " WITH PARSER %s" % (parser, )
+ text += " WITH PARSER %s" % (parser,)
- using = index.dialect_options['mysql']['using']
+ using = index.dialect_options["mysql"]["using"]
if using is not None:
text += " USING %s" % (preparer.quote(using))
return text
def visit_primary_key_constraint(self, constraint):
- text = super(MySQLDDLCompiler, self).\
- visit_primary_key_constraint(constraint)
- using = constraint.dialect_options['mysql']['using']
+ text = super(MySQLDDLCompiler, self).visit_primary_key_constraint(
+ constraint
+ )
+ using = constraint.dialect_options["mysql"]["using"]
if using:
text += " USING %s" % (self.preparer.quote(using))
return text
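
The ``mysql_length`` handling above accepts either a single integer prefix length applied to every column or a per-column mapping; a sketch against a hypothetical two-column table:

    from sqlalchemy import Column, Index, MetaData, String, Table

    t = Table(
        "t", MetaData(),
        Column("a", String(255)),
        Column("b", String(255)),
    )
    Index("ix_1", t.c.a, t.c.b, mysql_length=10)
    # CREATE INDEX ix_1 ON t (a(10), b(10))
    Index("ix_2", t.c.a, t.c.b, mysql_length={"a": 4, "b": 9})
    # CREATE INDEX ix_2 ON t (a(4), b(9))
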
@@ -1368,9 +1676,9 @@ class MySQLDDLCompiler(compiler.DDLCompiler):
index = drop.element
return "\nDROP INDEX %s ON %s" % (
- self._prepared_index_name(index,
- include_schema=False),
- self.preparer.format_table(index.table))
+ self._prepared_index_name(index, include_schema=False),
+ self.preparer.format_table(index.table),
+ )
def visit_drop_constraint(self, drop):
constraint = drop.element
@@ -1386,29 +1694,33 @@ class MySQLDDLCompiler(compiler.DDLCompiler):
else:
qual = ""
const = self.preparer.format_constraint(constraint)
- return "ALTER TABLE %s DROP %s%s" % \
- (self.preparer.format_table(constraint.table),
- qual, const)
+ return "ALTER TABLE %s DROP %s%s" % (
+ self.preparer.format_table(constraint.table),
+ qual,
+ const,
+ )
def define_constraint_match(self, constraint):
if constraint.match is not None:
raise exc.CompileError(
"MySQL ignores the 'MATCH' keyword while at the same time "
- "causes ON UPDATE/ON DELETE clauses to be ignored.")
+ "causes ON UPDATE/ON DELETE clauses to be ignored."
+ )
return ""
def visit_set_table_comment(self, create):
return "ALTER TABLE %s COMMENT %s" % (
self.preparer.format_table(create.element),
self.sql_compiler.render_literal_value(
- create.element.comment, sqltypes.String())
+ create.element.comment, sqltypes.String()
+ ),
)
def visit_set_column_comment(self, create):
return "ALTER TABLE %s CHANGE %s %s" % (
self.preparer.format_table(create.element.table),
self.preparer.format_column(create.element),
- self.get_column_specification(create.element)
+ self.get_column_specification(create.element),
)
@@ -1420,9 +1732,9 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
return spec
if type_.unsigned:
- spec += ' UNSIGNED'
+ spec += " UNSIGNED"
if type_.zerofill:
- spec += ' ZEROFILL'
+ spec += " ZEROFILL"
return spec
def _extend_string(self, type_, defaults, spec):
@@ -1434,28 +1746,30 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
def attr(name):
return getattr(type_, name, defaults.get(name))
- if attr('charset'):
- charset = 'CHARACTER SET %s' % attr('charset')
- elif attr('ascii'):
- charset = 'ASCII'
- elif attr('unicode'):
- charset = 'UNICODE'
+ if attr("charset"):
+ charset = "CHARACTER SET %s" % attr("charset")
+ elif attr("ascii"):
+ charset = "ASCII"
+ elif attr("unicode"):
+ charset = "UNICODE"
else:
charset = None
- if attr('collation'):
- collation = 'COLLATE %s' % type_.collation
- elif attr('binary'):
- collation = 'BINARY'
+ if attr("collation"):
+ collation = "COLLATE %s" % type_.collation
+ elif attr("binary"):
+ collation = "BINARY"
else:
collation = None
- if attr('national'):
+ if attr("national"):
# NATIONAL (aka NCHAR/NVARCHAR) trumps charsets.
- return ' '.join([c for c in ('NATIONAL', spec, collation)
- if c is not None])
- return ' '.join([c for c in (spec, charset, collation)
- if c is not None])
+ return " ".join(
+ [c for c in ("NATIONAL", spec, collation) if c is not None]
+ )
+ return " ".join(
+ [c for c in (spec, charset, collation) if c is not None]
+ )
def _mysql_type(self, type_):
return isinstance(type_, (_StringType, _NumericType))
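
``_extend_numeric()`` and ``_extend_string()`` above decorate the base type name with the MySQL-only attributes; a few illustrative compilations (the charset and collation values are examples):

    from sqlalchemy.dialects.mysql import INTEGER, NVARCHAR, VARCHAR

    INTEGER(display_width=11, unsigned=True)  # INTEGER(11) UNSIGNED
    VARCHAR(30, charset="utf8")               # VARCHAR(30) CHARACTER SET utf8
    VARCHAR(30, collation="utf8_bin")         # VARCHAR(30) COLLATE utf8_bin
    NVARCHAR(30)                              # NATIONAL VARCHAR(30)
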
@@ -1464,95 +1778,113 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
if type_.precision is None:
return self._extend_numeric(type_, "NUMERIC")
elif type_.scale is None:
- return self._extend_numeric(type_,
- "NUMERIC(%(precision)s)" %
- {'precision': type_.precision})
+ return self._extend_numeric(
+ type_,
+ "NUMERIC(%(precision)s)" % {"precision": type_.precision},
+ )
else:
- return self._extend_numeric(type_,
- "NUMERIC(%(precision)s, %(scale)s)" %
- {'precision': type_.precision,
- 'scale': type_.scale})
+ return self._extend_numeric(
+ type_,
+ "NUMERIC(%(precision)s, %(scale)s)"
+ % {"precision": type_.precision, "scale": type_.scale},
+ )
def visit_DECIMAL(self, type_, **kw):
if type_.precision is None:
return self._extend_numeric(type_, "DECIMAL")
elif type_.scale is None:
- return self._extend_numeric(type_,
- "DECIMAL(%(precision)s)" %
- {'precision': type_.precision})
+ return self._extend_numeric(
+ type_,
+ "DECIMAL(%(precision)s)" % {"precision": type_.precision},
+ )
else:
- return self._extend_numeric(type_,
- "DECIMAL(%(precision)s, %(scale)s)" %
- {'precision': type_.precision,
- 'scale': type_.scale})
+ return self._extend_numeric(
+ type_,
+ "DECIMAL(%(precision)s, %(scale)s)"
+ % {"precision": type_.precision, "scale": type_.scale},
+ )
def visit_DOUBLE(self, type_, **kw):
if type_.precision is not None and type_.scale is not None:
- return self._extend_numeric(type_,
- "DOUBLE(%(precision)s, %(scale)s)" %
- {'precision': type_.precision,
- 'scale': type_.scale})
+ return self._extend_numeric(
+ type_,
+ "DOUBLE(%(precision)s, %(scale)s)"
+ % {"precision": type_.precision, "scale": type_.scale},
+ )
else:
- return self._extend_numeric(type_, 'DOUBLE')
+ return self._extend_numeric(type_, "DOUBLE")
def visit_REAL(self, type_, **kw):
if type_.precision is not None and type_.scale is not None:
- return self._extend_numeric(type_,
- "REAL(%(precision)s, %(scale)s)" %
- {'precision': type_.precision,
- 'scale': type_.scale})
+ return self._extend_numeric(
+ type_,
+ "REAL(%(precision)s, %(scale)s)"
+ % {"precision": type_.precision, "scale": type_.scale},
+ )
else:
- return self._extend_numeric(type_, 'REAL')
+ return self._extend_numeric(type_, "REAL")
def visit_FLOAT(self, type_, **kw):
- if self._mysql_type(type_) and \
- type_.scale is not None and \
- type_.precision is not None:
+ if (
+ self._mysql_type(type_)
+ and type_.scale is not None
+ and type_.precision is not None
+ ):
return self._extend_numeric(
- type_, "FLOAT(%s, %s)" % (type_.precision, type_.scale))
+ type_, "FLOAT(%s, %s)" % (type_.precision, type_.scale)
+ )
elif type_.precision is not None:
- return self._extend_numeric(type_,
- "FLOAT(%s)" % (type_.precision,))
+ return self._extend_numeric(
+ type_, "FLOAT(%s)" % (type_.precision,)
+ )
else:
return self._extend_numeric(type_, "FLOAT")
def visit_INTEGER(self, type_, **kw):
if self._mysql_type(type_) and type_.display_width is not None:
return self._extend_numeric(
- type_, "INTEGER(%(display_width)s)" %
- {'display_width': type_.display_width})
+ type_,
+ "INTEGER(%(display_width)s)"
+ % {"display_width": type_.display_width},
+ )
else:
return self._extend_numeric(type_, "INTEGER")
def visit_BIGINT(self, type_, **kw):
if self._mysql_type(type_) and type_.display_width is not None:
return self._extend_numeric(
- type_, "BIGINT(%(display_width)s)" %
- {'display_width': type_.display_width})
+ type_,
+ "BIGINT(%(display_width)s)"
+ % {"display_width": type_.display_width},
+ )
else:
return self._extend_numeric(type_, "BIGINT")
def visit_MEDIUMINT(self, type_, **kw):
if self._mysql_type(type_) and type_.display_width is not None:
return self._extend_numeric(
- type_, "MEDIUMINT(%(display_width)s)" %
- {'display_width': type_.display_width})
+ type_,
+ "MEDIUMINT(%(display_width)s)"
+ % {"display_width": type_.display_width},
+ )
else:
return self._extend_numeric(type_, "MEDIUMINT")
def visit_TINYINT(self, type_, **kw):
if self._mysql_type(type_) and type_.display_width is not None:
- return self._extend_numeric(type_,
- "TINYINT(%s)" % type_.display_width)
+ return self._extend_numeric(
+ type_, "TINYINT(%s)" % type_.display_width
+ )
else:
return self._extend_numeric(type_, "TINYINT")
def visit_SMALLINT(self, type_, **kw):
if self._mysql_type(type_) and type_.display_width is not None:
- return self._extend_numeric(type_,
- "SMALLINT(%(display_width)s)" %
- {'display_width': type_.display_width}
- )
+ return self._extend_numeric(
+ type_,
+ "SMALLINT(%(display_width)s)"
+ % {"display_width": type_.display_width},
+ )
else:
return self._extend_numeric(type_, "SMALLINT")
@@ -1563,7 +1895,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
return "BIT"
def visit_DATETIME(self, type_, **kw):
- if getattr(type_, 'fsp', None):
+ if getattr(type_, "fsp", None):
return "DATETIME(%d)" % type_.fsp
else:
return "DATETIME"
@@ -1572,13 +1904,13 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
return "DATE"
def visit_TIME(self, type_, **kw):
- if getattr(type_, 'fsp', None):
+ if getattr(type_, "fsp", None):
return "TIME(%d)" % type_.fsp
else:
return "TIME"
def visit_TIMESTAMP(self, type_, **kw):
- if getattr(type_, 'fsp', None):
+ if getattr(type_, "fsp", None):
return "TIMESTAMP(%d)" % type_.fsp
else:
return "TIMESTAMP"
@@ -1606,17 +1938,17 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
def visit_VARCHAR(self, type_, **kw):
if type_.length:
- return self._extend_string(
- type_, {}, "VARCHAR(%d)" % type_.length)
+ return self._extend_string(type_, {}, "VARCHAR(%d)" % type_.length)
else:
raise exc.CompileError(
- "VARCHAR requires a length on dialect %s" %
- self.dialect.name)
+ "VARCHAR requires a length on dialect %s" % self.dialect.name
+ )
def visit_CHAR(self, type_, **kw):
if type_.length:
- return self._extend_string(type_, {}, "CHAR(%(length)s)" %
- {'length': type_.length})
+ return self._extend_string(
+ type_, {}, "CHAR(%(length)s)" % {"length": type_.length}
+ )
else:
return self._extend_string(type_, {}, "CHAR")
@@ -1625,22 +1957,26 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
# of "NVARCHAR".
if type_.length:
return self._extend_string(
- type_, {'national': True},
- "VARCHAR(%(length)s)" % {'length': type_.length})
+ type_,
+ {"national": True},
+ "VARCHAR(%(length)s)" % {"length": type_.length},
+ )
else:
raise exc.CompileError(
- "NVARCHAR requires a length on dialect %s" %
- self.dialect.name)
+ "NVARCHAR requires a length on dialect %s" % self.dialect.name
+ )
def visit_NCHAR(self, type_, **kw):
# We'll actually generate the equiv.
# "NATIONAL CHAR" instead of "NCHAR".
if type_.length:
return self._extend_string(
- type_, {'national': True},
- "CHAR(%(length)s)" % {'length': type_.length})
+ type_,
+ {"national": True},
+ "CHAR(%(length)s)" % {"length": type_.length},
+ )
else:
- return self._extend_string(type_, {'national': True}, "CHAR")
+ return self._extend_string(type_, {"national": True}, "CHAR")
def visit_VARBINARY(self, type_, **kw):
return "VARBINARY(%d)" % type_.length
@@ -1676,17 +2012,19 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
quoted_enums = []
for e in enumerated_values:
quoted_enums.append("'%s'" % e.replace("'", "''"))
- return self._extend_string(type_, {}, "%s(%s)" % (
- name, ",".join(quoted_enums))
+ return self._extend_string(
+ type_, {}, "%s(%s)" % (name, ",".join(quoted_enums))
)
def visit_ENUM(self, type_, **kw):
- return self._visit_enumerated_values("ENUM", type_,
- type_._enumerated_values)
+ return self._visit_enumerated_values(
+ "ENUM", type_, type_._enumerated_values
+ )
def visit_SET(self, type_, **kw):
- return self._visit_enumerated_values("SET", type_,
- type_._enumerated_values)
+ return self._visit_enumerated_values(
+ "SET", type_, type_._enumerated_values
+ )
def visit_BOOLEAN(self, type, **kw):
return "BOOL"
@@ -1703,9 +2041,8 @@ class MySQLIdentifierPreparer(compiler.IdentifierPreparer):
quote = '"'
super(MySQLIdentifierPreparer, self).__init__(
- dialect,
- initial_quote=quote,
- escape_quote=quote)
+ dialect, initial_quote=quote, escape_quote=quote
+ )
def _quote_free_identifiers(self, *ids):
"""Unilaterally identifier-quote any number of strings."""
@@ -1719,7 +2056,7 @@ class MySQLDialect(default.DefaultDialect):
Not used directly in application code.
"""
- name = 'mysql'
+ name = "mysql"
supports_alter = True
# MySQL has no true "boolean" type; we
@@ -1738,7 +2075,7 @@ class MySQLDialect(default.DefaultDialect):
supports_comments = True
inline_comments = True
- default_paramstyle = 'format'
+ default_paramstyle = "format"
colspecs = colspecs
cte_follows_insert = True
@@ -1756,26 +2093,28 @@ class MySQLDialect(default.DefaultDialect):
_server_ansiquotes = False
construct_arguments = [
- (sa_schema.Table, {
- "*": None
- }),
- (sql.Update, {
- "limit": None
- }),
- (sa_schema.PrimaryKeyConstraint, {
- "using": None
- }),
- (sa_schema.Index, {
- "using": None,
- "length": None,
- "prefix": None,
- "with_parser": None
- })
+ (sa_schema.Table, {"*": None}),
+ (sql.Update, {"limit": None}),
+ (sa_schema.PrimaryKeyConstraint, {"using": None}),
+ (
+ sa_schema.Index,
+ {
+ "using": None,
+ "length": None,
+ "prefix": None,
+ "with_parser": None,
+ },
+ ),
]
- def __init__(self, isolation_level=None, json_serializer=None,
- json_deserializer=None, **kwargs):
- kwargs.pop('use_ansiquotes', None) # legacy
+ def __init__(
+ self,
+ isolation_level=None,
+ json_serializer=None,
+ json_deserializer=None,
+ **kwargs
+ ):
+ kwargs.pop("use_ansiquotes", None) # legacy
default.DefaultDialect.__init__(self, **kwargs)
self.isolation_level = isolation_level
self._json_serializer = json_serializer
@@ -1783,22 +2122,30 @@ class MySQLDialect(default.DefaultDialect):
def on_connect(self):
if self.isolation_level is not None:
+
def connect(conn):
self.set_isolation_level(conn, self.isolation_level)
+
return connect
else:
return None
- _isolation_lookup = set(['SERIALIZABLE', 'READ UNCOMMITTED',
- 'READ COMMITTED', 'REPEATABLE READ'])
+ _isolation_lookup = set(
+ [
+ "SERIALIZABLE",
+ "READ UNCOMMITTED",
+ "READ COMMITTED",
+ "REPEATABLE READ",
+ ]
+ )
def set_isolation_level(self, connection, level):
- level = level.replace('_', ' ')
+ level = level.replace("_", " ")
# adjust for ConnectionFairy being present
# allows attribute set e.g. "connection.autocommit = True"
# to work properly
- if hasattr(connection, 'connection'):
+ if hasattr(connection, "connection"):
connection = connection.connection
self._set_isolation_level(connection, level)
@@ -1807,8 +2154,8 @@ class MySQLDialect(default.DefaultDialect):
if level not in self._isolation_lookup:
raise exc.ArgumentError(
"Invalid value '%s' for isolation_level. "
- "Valid isolation levels for %s are %s" %
- (level, self.name, ", ".join(self._isolation_lookup))
+ "Valid isolation levels for %s are %s"
+ % (level, self.name, ", ".join(self._isolation_lookup))
)
cursor = connection.cursor()
cursor.execute("SET SESSION TRANSACTION ISOLATION LEVEL %s" % level)
@@ -1818,9 +2165,9 @@ class MySQLDialect(default.DefaultDialect):
def get_isolation_level(self, connection):
cursor = connection.cursor()
if self._is_mysql and self.server_version_info >= (5, 7, 20):
- cursor.execute('SELECT @@transaction_isolation')
+ cursor.execute("SELECT @@transaction_isolation")
else:
- cursor.execute('SELECT @@tx_isolation')
+ cursor.execute("SELECT @@tx_isolation")
val = cursor.fetchone()[0]
cursor.close()
if util.py3k and isinstance(val, bytes):
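The setter and getter above back the public isolation_level hook; a usage sketch (URL is hypothetical):

    from sqlalchemy import create_engine
    engine = create_engine(
        "mysql+pymysql://scott:tiger@localhost/test",  # hypothetical URL
        isolation_level="READ COMMITTED",  # any value in _isolation_lookup
    )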
@@ -1840,7 +2187,7 @@ class MySQLDialect(default.DefaultDialect):
val = val.decode()
version = []
- r = re.compile(r'[.\-]')
+ r = re.compile(r"[.\-]")
for n in r.split(val):
try:
version.append(int(n))
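The (elided) except branch evidently keeps non-numeric tokens as strings, since _is_mariadb below tests for the "MariaDB" marker inside server_version_info; a standalone sketch, with an illustrative banner string:

    import re
    r = re.compile(r"[.\-]")
    version = []
    for n in r.split("5.5.5-10.2.9-MariaDB"):
        try:
            version.append(int(n))
        except ValueError:
            version.append(n)
    assert "MariaDB" in version  # -> [5, 5, 5, 10, 2, 9, 'MariaDB']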
@@ -1885,29 +2232,38 @@ class MySQLDialect(default.DefaultDialect):
connection.execute(sql.text("XA END :xid"), xid=xid)
connection.execute(sql.text("XA PREPARE :xid"), xid=xid)
- def do_rollback_twophase(self, connection, xid, is_prepared=True,
- recover=False):
+ def do_rollback_twophase(
+ self, connection, xid, is_prepared=True, recover=False
+ ):
if not is_prepared:
connection.execute(sql.text("XA END :xid"), xid=xid)
connection.execute(sql.text("XA ROLLBACK :xid"), xid=xid)
- def do_commit_twophase(self, connection, xid, is_prepared=True,
- recover=False):
+ def do_commit_twophase(
+ self, connection, xid, is_prepared=True, recover=False
+ ):
if not is_prepared:
self.do_prepare_twophase(connection, xid)
connection.execute(sql.text("XA COMMIT :xid"), xid=xid)
def do_recover_twophase(self, connection):
resultset = connection.execute("XA RECOVER")
- return [row['data'][0:row['gtrid_length']] for row in resultset]
+ return [row["data"][0 : row["gtrid_length"]] for row in resultset]
def is_disconnect(self, e, connection, cursor):
- if isinstance(e, (self.dbapi.OperationalError,
- self.dbapi.ProgrammingError)):
- return self._extract_error_code(e) in \
- (2006, 2013, 2014, 2045, 2055)
+ if isinstance(
+ e, (self.dbapi.OperationalError, self.dbapi.ProgrammingError)
+ ):
+ return self._extract_error_code(e) in (
+ 2006,
+ 2013,
+ 2014,
+ 2045,
+ 2055,
+ )
elif isinstance(
- e, (self.dbapi.InterfaceError, self.dbapi.InternalError)):
+ e, (self.dbapi.InterfaceError, self.dbapi.InternalError)
+ ):
# if underlying connection is closed,
# this is the error you get
return "(0, '')" in str(e)
@@ -1944,7 +2300,7 @@ class MySQLDialect(default.DefaultDialect):
raise NotImplementedError()
def _get_default_schema_name(self, connection):
- return connection.execute('SELECT DATABASE()').scalar()
+ return connection.execute("SELECT DATABASE()").scalar()
def has_table(self, connection, table_name, schema=None):
# SHOW TABLE STATUS LIKE and SHOW TABLES LIKE do not function properly
@@ -1957,15 +2313,19 @@ class MySQLDialect(default.DefaultDialect):
# full_name = self.identifier_preparer.format_table(table,
# use_schema=True)
- full_name = '.'.join(self.identifier_preparer._quote_free_identifiers(
- schema, table_name))
+ full_name = ".".join(
+ self.identifier_preparer._quote_free_identifiers(
+ schema, table_name
+ )
+ )
st = "DESCRIBE %s" % full_name
rs = None
try:
try:
rs = connection.execution_options(
- skip_user_error_events=True).execute(st)
+ skip_user_error_events=True
+ ).execute(st)
have = rs.fetchone() is not None
rs.close()
return have
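has_table() above probes with a bare DESCRIBE rather than SHOW TABLES LIKE; from application code it is reached through the public reflection entry points, a minimal sketch (URL and table name are hypothetical):

    from sqlalchemy import create_engine, inspect
    engine = create_engine("mysql+pymysql://scott:tiger@localhost/test")  # hypothetical URL
    print(inspect(engine).get_table_names())  # SHOW FULL TABLES under the hood
    print(engine.has_table("docs"))           # issues DESCRIBE docs, as above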
@@ -1986,12 +2346,13 @@ class MySQLDialect(default.DefaultDialect):
# if ansiquotes == True, build a new IdentifierPreparer
# with the new setting
self.identifier_preparer = self.preparer(
- self, server_ansiquotes=self._server_ansiquotes)
+ self, server_ansiquotes=self._server_ansiquotes
+ )
default.DefaultDialect.initialize(self, connection)
self._needs_correct_for_88718 = (
- not self._is_mariadb and self.server_version_info >= (8, )
+ not self._is_mariadb and self.server_version_info >= (8,)
)
self._warn_for_known_db_issues()
@@ -2007,20 +2368,23 @@ class MySQLDialect(default.DefaultDialect):
"additional issue prevents proper migrations of columns "
"with CHECK constraints (MDEV-11114). Please upgrade to "
"MariaDB 10.2.9 or greater, or use the MariaDB 10.1 "
- "series, to avoid these issues." % (mdb_version, ))
+ "series, to avoid these issues." % (mdb_version,)
+ )
@property
def _is_mariadb(self):
- return 'MariaDB' in self.server_version_info
+ return "MariaDB" in self.server_version_info
@property
def _is_mysql(self):
- return 'MariaDB' not in self.server_version_info
+ return "MariaDB" not in self.server_version_info
@property
def _is_mariadb_102(self):
- return self._is_mariadb and \
- self._mariadb_normalized_version_info > (10, 2)
+ return self._is_mariadb and self._mariadb_normalized_version_info > (
+ 10,
+ 2,
+ )
@property
def _mariadb_normalized_version_info(self):
@@ -2028,15 +2392,17 @@ class MySQLDialect(default.DefaultDialect):
# the string "5.5"; now that we use @@version we no longer see this.
if self._is_mariadb:
- idx = self.server_version_info.index('MariaDB')
- return self.server_version_info[idx - 3: idx]
+ idx = self.server_version_info.index("MariaDB")
+ return self.server_version_info[idx - 3 : idx]
else:
return self.server_version_info
@property
def _supports_cast(self):
- return self.server_version_info is None or \
- self.server_version_info >= (4, 0, 2)
+ return (
+ self.server_version_info is None
+ or self.server_version_info >= (4, 0, 2)
+ )
@reflection.cache
def get_schema_names(self, connection, **kw):
@@ -2054,18 +2420,23 @@ class MySQLDialect(default.DefaultDialect):
charset = self._connection_charset
if self.server_version_info < (5, 0, 2):
rp = connection.execute(
- "SHOW TABLES FROM %s" %
- self.identifier_preparer.quote_identifier(current_schema))
- return [row[0] for
- row in self._compat_fetchall(rp, charset=charset)]
+ "SHOW TABLES FROM %s"
+ % self.identifier_preparer.quote_identifier(current_schema)
+ )
+ return [
+ row[0] for row in self._compat_fetchall(rp, charset=charset)
+ ]
else:
rp = connection.execute(
- "SHOW FULL TABLES FROM %s" %
- self.identifier_preparer.quote_identifier(current_schema))
+ "SHOW FULL TABLES FROM %s"
+ % self.identifier_preparer.quote_identifier(current_schema)
+ )
- return [row[0]
- for row in self._compat_fetchall(rp, charset=charset)
- if row[1] == 'BASE TABLE']
+ return [
+ row[0]
+ for row in self._compat_fetchall(rp, charset=charset)
+ if row[1] == "BASE TABLE"
+ ]
@reflection.cache
def get_view_names(self, connection, schema=None, **kw):
@@ -2077,72 +2448,77 @@ class MySQLDialect(default.DefaultDialect):
return self.get_table_names(connection, schema)
charset = self._connection_charset
rp = connection.execute(
- "SHOW FULL TABLES FROM %s" %
- self.identifier_preparer.quote_identifier(schema))
- return [row[0]
- for row in self._compat_fetchall(rp, charset=charset)
- if row[1] in ('VIEW', 'SYSTEM VIEW')]
+ "SHOW FULL TABLES FROM %s"
+ % self.identifier_preparer.quote_identifier(schema)
+ )
+ return [
+ row[0]
+ for row in self._compat_fetchall(rp, charset=charset)
+ if row[1] in ("VIEW", "SYSTEM VIEW")
+ ]
@reflection.cache
def get_table_options(self, connection, table_name, schema=None, **kw):
parsed_state = self._parsed_state_or_create(
- connection, table_name, schema, **kw)
+ connection, table_name, schema, **kw
+ )
return parsed_state.table_options
@reflection.cache
def get_columns(self, connection, table_name, schema=None, **kw):
parsed_state = self._parsed_state_or_create(
- connection, table_name, schema, **kw)
+ connection, table_name, schema, **kw
+ )
return parsed_state.columns
@reflection.cache
def get_pk_constraint(self, connection, table_name, schema=None, **kw):
parsed_state = self._parsed_state_or_create(
- connection, table_name, schema, **kw)
+ connection, table_name, schema, **kw
+ )
for key in parsed_state.keys:
- if key['type'] == 'PRIMARY':
+ if key["type"] == "PRIMARY":
# There can be only one.
- cols = [s[0] for s in key['columns']]
- return {'constrained_columns': cols, 'name': None}
- return {'constrained_columns': [], 'name': None}
+ cols = [s[0] for s in key["columns"]]
+ return {"constrained_columns": cols, "name": None}
+ return {"constrained_columns": [], "name": None}
@reflection.cache
def get_foreign_keys(self, connection, table_name, schema=None, **kw):
parsed_state = self._parsed_state_or_create(
- connection, table_name, schema, **kw)
+ connection, table_name, schema, **kw
+ )
default_schema = None
fkeys = []
for spec in parsed_state.fk_constraints:
- ref_name = spec['table'][-1]
- ref_schema = len(spec['table']) > 1 and \
- spec['table'][-2] or schema
+ ref_name = spec["table"][-1]
+ ref_schema = len(spec["table"]) > 1 and spec["table"][-2] or schema
if not ref_schema:
if default_schema is None:
- default_schema = \
- connection.dialect.default_schema_name
+ default_schema = connection.dialect.default_schema_name
if schema == default_schema:
ref_schema = schema
- loc_names = spec['local']
- ref_names = spec['foreign']
+ loc_names = spec["local"]
+ ref_names = spec["foreign"]
con_kw = {}
- for opt in ('onupdate', 'ondelete'):
+ for opt in ("onupdate", "ondelete"):
if spec.get(opt, False):
con_kw[opt] = spec[opt]
fkey_d = {
- 'name': spec['name'],
- 'constrained_columns': loc_names,
- 'referred_schema': ref_schema,
- 'referred_table': ref_name,
- 'referred_columns': ref_names,
- 'options': con_kw
+ "name": spec["name"],
+ "constrained_columns": loc_names,
+ "referred_schema": ref_schema,
+ "referred_table": ref_name,
+ "referred_columns": ref_names,
+ "options": con_kw,
}
fkeys.append(fkey_d)
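Each reflection getter above pulls from one cached parse of SHOW CREATE TABLE; through the public Inspector they compose like this (URL and table name are hypothetical):

    from sqlalchemy import create_engine, inspect
    insp = inspect(create_engine("mysql+pymysql://scott:tiger@localhost/test"))  # hypothetical
    cols = insp.get_columns("docs")      # -> parsed_state.columns
    pk = insp.get_pk_constraint("docs")  # -> the PRIMARY key spec
    fks = insp.get_foreign_keys("docs")  # -> dicts shaped like fkey_d above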
@@ -2172,25 +2548,26 @@ class MySQLDialect(default.DefaultDialect):
default_schema_name = connection.dialect.default_schema_name
col_tuples = [
(
- lower(rec['referred_schema'] or default_schema_name),
- lower(rec['referred_table']),
- col_name
+ lower(rec["referred_schema"] or default_schema_name),
+ lower(rec["referred_table"]),
+ col_name,
)
for rec in fkeys
- for col_name in rec['referred_columns']
+ for col_name in rec["referred_columns"]
]
if col_tuples:
correct_for_wrong_fk_case = connection.execute(
- sql.text("""
+ sql.text(
+ """
select table_schema, table_name, column_name
from information_schema.columns
where (table_schema, table_name, lower(column_name)) in
:table_data;
- """).bindparams(
- sql.bindparam("table_data", expanding=True)
- ), table_data=col_tuples
+ """
+ ).bindparams(sql.bindparam("table_data", expanding=True)),
+ table_data=col_tuples,
)
# in casing=0, table name and schema name come back in their
@@ -2208,109 +2585,117 @@ class MySQLDialect(default.DefaultDialect):
d[(lower(schema), lower(tname))][cname.lower()] = cname
for fkey in fkeys:
- fkey['referred_columns'] = [
+ fkey["referred_columns"] = [
d[
(
lower(
- fkey['referred_schema'] or
- default_schema_name),
- lower(fkey['referred_table'])
+ fkey["referred_schema"] or default_schema_name
+ ),
+ lower(fkey["referred_table"]),
)
][col.lower()]
- for col in fkey['referred_columns']
+ for col in fkey["referred_columns"]
]
@reflection.cache
- def get_check_constraints(
- self, connection, table_name, schema=None, **kw):
+ def get_check_constraints(self, connection, table_name, schema=None, **kw):
parsed_state = self._parsed_state_or_create(
- connection, table_name, schema, **kw)
+ connection, table_name, schema, **kw
+ )
return [
- {"name": spec['name'], "sqltext": spec['sqltext']}
+ {"name": spec["name"], "sqltext": spec["sqltext"]}
for spec in parsed_state.ck_constraints
]
@reflection.cache
def get_table_comment(self, connection, table_name, schema=None, **kw):
parsed_state = self._parsed_state_or_create(
- connection, table_name, schema, **kw)
- return {"text": parsed_state.table_options.get('mysql_comment', None)}
+ connection, table_name, schema, **kw
+ )
+ return {"text": parsed_state.table_options.get("mysql_comment", None)}
@reflection.cache
def get_indexes(self, connection, table_name, schema=None, **kw):
parsed_state = self._parsed_state_or_create(
- connection, table_name, schema, **kw)
+ connection, table_name, schema, **kw
+ )
indexes = []
for spec in parsed_state.keys:
dialect_options = {}
unique = False
- flavor = spec['type']
- if flavor == 'PRIMARY':
+ flavor = spec["type"]
+ if flavor == "PRIMARY":
continue
- if flavor == 'UNIQUE':
+ if flavor == "UNIQUE":
unique = True
- elif flavor in ('FULLTEXT', 'SPATIAL'):
+ elif flavor in ("FULLTEXT", "SPATIAL"):
dialect_options["mysql_prefix"] = flavor
elif flavor is None:
pass
else:
self.logger.info(
- "Converting unknown KEY type %s to a plain KEY", flavor)
+ "Converting unknown KEY type %s to a plain KEY", flavor
+ )
pass
- if spec['parser']:
- dialect_options['mysql_with_parser'] = spec['parser']
+ if spec["parser"]:
+ dialect_options["mysql_with_parser"] = spec["parser"]
index_d = {}
if dialect_options:
index_d["dialect_options"] = dialect_options
- index_d['name'] = spec['name']
- index_d['column_names'] = [s[0] for s in spec['columns']]
- index_d['unique'] = unique
+ index_d["name"] = spec["name"]
+ index_d["column_names"] = [s[0] for s in spec["columns"]]
+ index_d["unique"] = unique
if flavor:
- index_d['type'] = flavor
+ index_d["type"] = flavor
indexes.append(index_d)
return indexes
@reflection.cache
- def get_unique_constraints(self, connection, table_name,
- schema=None, **kw):
+ def get_unique_constraints(
+ self, connection, table_name, schema=None, **kw
+ ):
parsed_state = self._parsed_state_or_create(
- connection, table_name, schema, **kw)
+ connection, table_name, schema, **kw
+ )
return [
{
- 'name': key['name'],
- 'column_names': [col[0] for col in key['columns']],
- 'duplicates_index': key['name'],
+ "name": key["name"],
+ "column_names": [col[0] for col in key["columns"]],
+ "duplicates_index": key["name"],
}
for key in parsed_state.keys
- if key['type'] == 'UNIQUE'
+ if key["type"] == "UNIQUE"
]
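A UNIQUE KEY surfaces both as an index and as a unique constraint, with duplicates_index flagging the overlap; a hedged sketch against a hypothetical engine:

    from sqlalchemy import create_engine, inspect
    insp = inspect(create_engine("mysql+pymysql://scott:tiger@localhost/test"))  # hypothetical
    unique_ix = {ix["name"] for ix in insp.get_indexes("docs") if ix["unique"]}
    for uq in insp.get_unique_constraints("docs"):
        assert uq["duplicates_index"] in unique_ix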
@reflection.cache
def get_view_definition(self, connection, view_name, schema=None, **kw):
charset = self._connection_charset
- full_name = '.'.join(self.identifier_preparer._quote_free_identifiers(
- schema, view_name))
- sql = self._show_create_table(connection, None, charset,
- full_name=full_name)
+ full_name = ".".join(
+ self.identifier_preparer._quote_free_identifiers(schema, view_name)
+ )
+ sql = self._show_create_table(
+ connection, None, charset, full_name=full_name
+ )
return sql
- def _parsed_state_or_create(self, connection, table_name,
- schema=None, **kw):
+ def _parsed_state_or_create(
+ self, connection, table_name, schema=None, **kw
+ ):
return self._setup_parser(
connection,
table_name,
schema,
- info_cache=kw.get('info_cache', None)
+ info_cache=kw.get("info_cache", None),
)
@util.memoized_property
@@ -2321,7 +2706,7 @@ class MySQLDialect(default.DefaultDialect):
retrieved server version information first.
"""
- if (self.server_version_info < (4, 1) and self._server_ansiquotes):
+ if self.server_version_info < (4, 1) and self._server_ansiquotes:
# ANSI_QUOTES doesn't affect SHOW CREATE TABLE on < 4.1
preparer = self.preparer(self, server_ansiquotes=False)
else:
@@ -2332,14 +2717,19 @@ class MySQLDialect(default.DefaultDialect):
def _setup_parser(self, connection, table_name, schema=None, **kw):
charset = self._connection_charset
parser = self._tabledef_parser
- full_name = '.'.join(self.identifier_preparer._quote_free_identifiers(
- schema, table_name))
- sql = self._show_create_table(connection, None, charset,
- full_name=full_name)
- if re.match(r'^CREATE (?:ALGORITHM)?.* VIEW', sql):
+ full_name = ".".join(
+ self.identifier_preparer._quote_free_identifiers(
+ schema, table_name
+ )
+ )
+ sql = self._show_create_table(
+ connection, None, charset, full_name=full_name
+ )
+ if re.match(r"^CREATE (?:ALGORITHM)?.* VIEW", sql):
# Adapt views to something table-like.
- columns = self._describe_table(connection, None, charset,
- full_name=full_name)
+ columns = self._describe_table(
+ connection, None, charset, full_name=full_name
+ )
sql = parser._describe_to_create(table_name, columns)
return parser.parse(sql, charset)
@@ -2356,17 +2746,18 @@ class MySQLDialect(default.DefaultDialect):
# http://dev.mysql.com/doc/refman/5.0/en/name-case-sensitivity.html
charset = self._connection_charset
- row = self._compat_first(connection.execute(
- "SHOW VARIABLES LIKE 'lower_case_table_names'"),
- charset=charset)
+ row = self._compat_first(
+ connection.execute("SHOW VARIABLES LIKE 'lower_case_table_names'"),
+ charset=charset,
+ )
if not row:
cs = 0
else:
# 4.0.15 returns OFF or ON according to [ticket:489]
# 3.23 doesn't, 4.0.27 doesn't..
- if row[1] == 'OFF':
+ if row[1] == "OFF":
cs = 0
- elif row[1] == 'ON':
+ elif row[1] == "ON":
cs = 1
else:
cs = int(row[1])
@@ -2384,7 +2775,7 @@ class MySQLDialect(default.DefaultDialect):
pass
else:
charset = self._connection_charset
- rs = connection.execute('SHOW COLLATION')
+ rs = connection.execute("SHOW COLLATION")
for row in self._compat_fetchall(rs, charset):
collations[row[0]] = row[1]
return collations
@@ -2392,33 +2783,36 @@ class MySQLDialect(default.DefaultDialect):
def _detect_sql_mode(self, connection):
row = self._compat_first(
connection.execute("SHOW VARIABLES LIKE 'sql_mode'"),
- charset=self._connection_charset)
+ charset=self._connection_charset,
+ )
if not row:
util.warn(
"Could not retrieve SQL_MODE; please ensure the "
- "MySQL user has permissions to SHOW VARIABLES")
- self._sql_mode = ''
+ "MySQL user has permissions to SHOW VARIABLES"
+ )
+ self._sql_mode = ""
else:
- self._sql_mode = row[1] or ''
+ self._sql_mode = row[1] or ""
def _detect_ansiquotes(self, connection):
"""Detect and adjust for the ANSI_QUOTES sql mode."""
mode = self._sql_mode
if not mode:
- mode = ''
+ mode = ""
elif mode.isdigit():
mode_no = int(mode)
- mode = (mode_no | 4 == mode_no) and 'ANSI_QUOTES' or ''
+ mode = (mode_no | 4 == mode_no) and "ANSI_QUOTES" or ""
- self._server_ansiquotes = 'ANSI_QUOTES' in mode
+ self._server_ansiquotes = "ANSI_QUOTES" in mode
# as of MySQL 5.0.1
- self._backslash_escapes = 'NO_BACKSLASH_ESCAPES' not in mode
+ self._backslash_escapes = "NO_BACKSLASH_ESCAPES" not in mode
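On servers old enough to report SQL_MODE as a numeric bitmask, ANSI_QUOTES is bit value 4, and mode_no | 4 == mode_no is simply a set-bit test:

    # sketch of the bitmask test above; 13 = 8 | 4 | 1 has the bit, 9 = 8 | 1 does not
    for mode_no, expected in [(4, True), (13, True), (9, False)]:
        assert (mode_no | 4 == mode_no) is expected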
- def _show_create_table(self, connection, table, charset=None,
- full_name=None):
+ def _show_create_table(
+ self, connection, table, charset=None, full_name=None
+ ):
"""Run SHOW CREATE TABLE for a ``Table``."""
if full_name is None:
@@ -2428,7 +2822,8 @@ class MySQLDialect(default.DefaultDialect):
rp = None
try:
rp = connection.execution_options(
- skip_user_error_events=True).execute(st)
+ skip_user_error_events=True
+ ).execute(st)
except exc.DBAPIError as e:
if self._extract_error_code(e.orig) == 1146:
raise exc.NoSuchTableError(full_name)
@@ -2441,8 +2836,7 @@ class MySQLDialect(default.DefaultDialect):
return sql
- def _describe_table(self, connection, table, charset=None,
- full_name=None):
+ def _describe_table(self, connection, table, charset=None, full_name=None):
"""Run DESCRIBE for a ``Table`` and return processed rows."""
if full_name is None:
@@ -2453,7 +2847,8 @@ class MySQLDialect(default.DefaultDialect):
try:
try:
rp = connection.execution_options(
- skip_user_error_events=True).execute(st)
+ skip_user_error_events=True
+ ).execute(st)
except exc.DBAPIError as e:
code = self._extract_error_code(e.orig)
if code == 1146:
@@ -2486,11 +2881,11 @@ class _DecodingRowProxy(object):
# seem to come up in DDL queries.
_encoding_compat = {
- 'koi8r': 'koi8_r',
- 'koi8u': 'koi8_u',
- 'utf16': 'utf-16-be', # MySQL's uft16 is always bigendian
- 'utf8mb4': 'utf8', # real utf8
- 'eucjpms': 'ujis',
+ "koi8r": "koi8_r",
+ "koi8u": "koi8_u",
+    "utf16": "utf-16-be",  # MySQL's utf16 is always big-endian

+ "utf8mb4": "utf8", # real utf8
+ "eucjpms": "ujis",
}
def __init__(self, rowproxy, charset):
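_encoding_compat maps MySQL charset names onto Python codec names before each field is decoded; a hedged sketch of the lookup (bytes and values are illustrative):

    # decode a raw value via the remapped codec name
    _encoding_compat = {"koi8r": "koi8_r", "utf8mb4": "utf8"}
    raw, charset = b"caf\xc3\xa9", "utf8mb4"
    assert raw.decode(_encoding_compat.get(charset, charset)) == "caf\xe9"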
diff --git a/lib/sqlalchemy/dialects/mysql/cymysql.py b/lib/sqlalchemy/dialects/mysql/cymysql.py
index d14290594..8a60608db 100644
--- a/lib/sqlalchemy/dialects/mysql/cymysql.py
+++ b/lib/sqlalchemy/dialects/mysql/cymysql.py
@@ -18,7 +18,7 @@
import re
from .mysqldb import MySQLDialect_mysqldb
-from .base import (BIT, MySQLDialect)
+from .base import BIT, MySQLDialect
from ... import util
@@ -34,27 +34,23 @@ class _cymysqlBIT(BIT):
v = v << 8 | i
return v
return value
+
return process
class MySQLDialect_cymysql(MySQLDialect_mysqldb):
- driver = 'cymysql'
+ driver = "cymysql"
description_encoding = None
supports_sane_rowcount = True
supports_sane_multi_rowcount = False
supports_unicode_statements = True
- colspecs = util.update_copy(
- MySQLDialect.colspecs,
- {
- BIT: _cymysqlBIT,
- }
- )
+ colspecs = util.update_copy(MySQLDialect.colspecs, {BIT: _cymysqlBIT})
@classmethod
def dbapi(cls):
- return __import__('cymysql')
+ return __import__("cymysql")
def _detect_charset(self, connection):
return connection.connection.charset
@@ -64,8 +60,13 @@ class MySQLDialect_cymysql(MySQLDialect_mysqldb):
def is_disconnect(self, e, connection, cursor):
if isinstance(e, self.dbapi.OperationalError):
- return self._extract_error_code(e) in \
- (2006, 2013, 2014, 2045, 2055)
+ return self._extract_error_code(e) in (
+ 2006,
+ 2013,
+ 2014,
+ 2045,
+ 2055,
+ )
elif isinstance(e, self.dbapi.InterfaceError):
# if underlying connection is closed,
# this is the error you get
@@ -73,4 +74,5 @@ class MySQLDialect_cymysql(MySQLDialect_mysqldb):
else:
return False
+
dialect = MySQLDialect_cymysql
diff --git a/lib/sqlalchemy/dialects/mysql/dml.py b/lib/sqlalchemy/dialects/mysql/dml.py
index 130ef2347..5d59b2073 100644
--- a/lib/sqlalchemy/dialects/mysql/dml.py
+++ b/lib/sqlalchemy/dialects/mysql/dml.py
@@ -6,7 +6,7 @@ from ...sql.base import _generative
from ... import exc
from ... import util
-__all__ = ('Insert', 'insert')
+__all__ = ("Insert", "insert")
class Insert(StandardInsert):
@@ -39,7 +39,7 @@ class Insert(StandardInsert):
@util.memoized_property
def inserted_alias(self):
- return alias(self.table, name='inserted')
+ return alias(self.table, name="inserted")
@_generative
def on_duplicate_key_update(self, *args, **kw):
@@ -87,27 +87,29 @@ class Insert(StandardInsert):
"""
if args and kw:
raise exc.ArgumentError(
- "Can't pass kwargs and positional arguments simultaneously")
+ "Can't pass kwargs and positional arguments simultaneously"
+ )
if args:
if len(args) > 1:
raise exc.ArgumentError(
"Only a single dictionary or list of tuples "
- "is accepted positionally.")
+ "is accepted positionally."
+ )
values = args[0]
else:
values = kw
- inserted_alias = getattr(self, 'inserted_alias', None)
+ inserted_alias = getattr(self, "inserted_alias", None)
self._post_values_clause = OnDuplicateClause(inserted_alias, values)
return self
-insert = public_factory(Insert, '.dialects.mysql.insert')
+insert = public_factory(Insert, ".dialects.mysql.insert")
class OnDuplicateClause(ClauseElement):
- __visit_name__ = 'on_duplicate_key_update'
+ __visit_name__ = "on_duplicate_key_update"
_parameter_ordering = None
@@ -118,11 +120,12 @@ class OnDuplicateClause(ClauseElement):
        # Update._process_colparams(), however we don't look for a special flag
# in this case since we are not disambiguating from other use cases as
# we are in Update.values().
- if isinstance(update, list) and \
- (update and isinstance(update[0], tuple)):
+ if isinstance(update, list) and (
+ update and isinstance(update[0], tuple)
+ ):
self._parameter_ordering = [key for key, value in update]
update = dict(update)
if not update or not isinstance(update, dict):
- raise ValueError('update parameter must be a non-empty dictionary')
+ raise ValueError("update parameter must be a non-empty dictionary")
self.update = update
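The construct above is reached through the documented mysql.insert() factory; a usage sketch (table and columns are hypothetical):

    from sqlalchemy import Column, Integer, MetaData, String, Table
    from sqlalchemy.dialects.mysql import insert
    kv = Table("kv", MetaData(),  # hypothetical table
               Column("id", Integer, primary_key=True),
               Column("data", String(50)))
    stmt = insert(kv).values(id=1, data="x")
    # stmt.inserted is the "inserted" alias built by inserted_alias above;
    # renders roughly: ON DUPLICATE KEY UPDATE data = VALUES(data)
    stmt = stmt.on_duplicate_key_update(data=stmt.inserted.data)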
diff --git a/lib/sqlalchemy/dialects/mysql/enumerated.py b/lib/sqlalchemy/dialects/mysql/enumerated.py
index f63d64e8f..9586eff3f 100644
--- a/lib/sqlalchemy/dialects/mysql/enumerated.py
+++ b/lib/sqlalchemy/dialects/mysql/enumerated.py
@@ -14,29 +14,30 @@ from ...sql import sqltypes
class _EnumeratedValues(_StringType):
def _init_values(self, values, kw):
- self.quoting = kw.pop('quoting', 'auto')
+ self.quoting = kw.pop("quoting", "auto")
- if self.quoting == 'auto' and len(values):
+ if self.quoting == "auto" and len(values):
# What quoting character are we using?
q = None
for e in values:
if len(e) == 0:
- self.quoting = 'unquoted'
+ self.quoting = "unquoted"
break
elif q is None:
q = e[0]
if len(e) == 1 or e[0] != q or e[-1] != q:
- self.quoting = 'unquoted'
+ self.quoting = "unquoted"
break
else:
- self.quoting = 'quoted'
+ self.quoting = "quoted"
- if self.quoting == 'quoted':
+ if self.quoting == "quoted":
util.warn_deprecated(
- 'Manually quoting %s value literals is deprecated. Supply '
- 'unquoted values and use the quoting= option in cases of '
- 'ambiguity.' % self.__class__.__name__)
+ "Manually quoting %s value literals is deprecated. Supply "
+ "unquoted values and use the quoting= option in cases of "
+ "ambiguity." % self.__class__.__name__
+ )
values = self._strip_values(values)
@@ -58,7 +59,7 @@ class _EnumeratedValues(_StringType):
class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum, _EnumeratedValues):
"""MySQL ENUM type."""
- __visit_name__ = 'ENUM'
+ __visit_name__ = "ENUM"
native_enum = True
@@ -115,7 +116,7 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum, _EnumeratedValues):
"""
- kw.pop('strict', None)
+ kw.pop("strict", None)
self._enum_init(enums, kw)
_StringType.__init__(self, length=self.length, **kw)
@@ -145,13 +146,14 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum, _EnumeratedValues):
def __repr__(self):
return util.generic_repr(
- self, to_inspect=[ENUM, _StringType, sqltypes.Enum])
+ self, to_inspect=[ENUM, _StringType, sqltypes.Enum]
+ )
class SET(_EnumeratedValues):
"""MySQL SET type."""
- __visit_name__ = 'SET'
+ __visit_name__ = "SET"
def __init__(self, *values, **kw):
"""Construct a SET.
@@ -216,45 +218,43 @@ class SET(_EnumeratedValues):
"""
- self.retrieve_as_bitwise = kw.pop('retrieve_as_bitwise', False)
+ self.retrieve_as_bitwise = kw.pop("retrieve_as_bitwise", False)
values, length = self._init_values(values, kw)
self.values = tuple(values)
- if not self.retrieve_as_bitwise and '' in values:
+ if not self.retrieve_as_bitwise and "" in values:
raise exc.ArgumentError(
"Can't use the blank value '' in a SET without "
- "setting retrieve_as_bitwise=True")
+ "setting retrieve_as_bitwise=True"
+ )
if self.retrieve_as_bitwise:
self._bitmap = dict(
- (value, 2 ** idx)
- for idx, value in enumerate(self.values)
+ (value, 2 ** idx) for idx, value in enumerate(self.values)
)
self._bitmap.update(
- (2 ** idx, value)
- for idx, value in enumerate(self.values)
+ (2 ** idx, value) for idx, value in enumerate(self.values)
)
- kw.setdefault('length', length)
+ kw.setdefault("length", length)
super(SET, self).__init__(**kw)
def column_expression(self, colexpr):
if self.retrieve_as_bitwise:
return sql.type_coerce(
- sql.type_coerce(colexpr, sqltypes.Integer) + 0,
- self
+ sql.type_coerce(colexpr, sqltypes.Integer) + 0, self
)
else:
return colexpr
def result_processor(self, dialect, coltype):
if self.retrieve_as_bitwise:
+
def process(value):
if value is not None:
value = int(value)
- return set(
- util.map_bits(self._bitmap.__getitem__, value)
- )
+ return set(util.map_bits(self._bitmap.__getitem__, value))
else:
return None
+
else:
super_convert = super(SET, self).result_processor(dialect, coltype)
@@ -263,18 +263,20 @@ class SET(_EnumeratedValues):
# MySQLdb returns a string, let's parse
if super_convert:
value = super_convert(value)
- return set(re.findall(r'[^,]+', value))
+ return set(re.findall(r"[^,]+", value))
else:
# mysql-connector-python does a naive
# split(",") which throws in an empty string
if value is not None:
- value.discard('')
+ value.discard("")
return value
+
return process
def bind_processor(self, dialect):
super_convert = super(SET, self).bind_processor(dialect)
if self.retrieve_as_bitwise:
+
def process(value):
if value is None:
return None
@@ -288,24 +290,23 @@ class SET(_EnumeratedValues):
for v in value:
int_value |= self._bitmap[v]
return int_value
+
else:
def process(value):
# accept strings and int (actually bitflag) values directly
if value is not None and not isinstance(
- value, util.int_types + util.string_types):
+ value, util.int_types + util.string_types
+ ):
value = ",".join(value)
if super_convert:
return super_convert(value)
else:
return value
+
return process
def adapt(self, impltype, **kw):
- kw['retrieve_as_bitwise'] = self.retrieve_as_bitwise
- return util.constructor_copy(
- self, impltype,
- *self.values,
- **kw
- )
+ kw["retrieve_as_bitwise"] = self.retrieve_as_bitwise
+ return util.constructor_copy(self, impltype, *self.values, **kw)
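With retrieve_as_bitwise=True, SET round-trips values through the integer bitmap built in __init__ above (value <-> 2**index); a declaration sketch with hypothetical names:

    from sqlalchemy import Column, Integer, MetaData, Table
    from sqlalchemy.dialects.mysql import SET
    prefs = Table("prefs", MetaData(),  # hypothetical table
                  Column("id", Integer, primary_key=True),
                  # "a" -> 1, "b" -> 2, "c" -> 4 in the internal _bitmap
                  Column("flags", SET("a", "b", "c", retrieve_as_bitwise=True)))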
diff --git a/lib/sqlalchemy/dialects/mysql/gaerdbms.py b/lib/sqlalchemy/dialects/mysql/gaerdbms.py
index 806e4c874..117cd28a2 100644
--- a/lib/sqlalchemy/dialects/mysql/gaerdbms.py
+++ b/lib/sqlalchemy/dialects/mysql/gaerdbms.py
@@ -44,11 +44,10 @@ from sqlalchemy.util import warn_deprecated
def _is_dev_environment():
- return os.environ.get('SERVER_SOFTWARE', '').startswith('Development/')
+ return os.environ.get("SERVER_SOFTWARE", "").startswith("Development/")
class MySQLDialect_gaerdbms(MySQLDialect_mysqldb):
-
@classmethod
def dbapi(cls):
@@ -69,12 +68,15 @@ class MySQLDialect_gaerdbms(MySQLDialect_mysqldb):
if _is_dev_environment():
from google.appengine.api import rdbms_mysqldb
+
return rdbms_mysqldb
- elif apiproxy_stub_map.apiproxy.GetStub('rdbms'):
+ elif apiproxy_stub_map.apiproxy.GetStub("rdbms"):
from google.storage.speckle.python.api import rdbms_apiproxy
+
return rdbms_apiproxy
else:
from google.storage.speckle.python.api import rdbms_googleapi
+
return rdbms_googleapi
@classmethod
@@ -87,8 +89,8 @@ class MySQLDialect_gaerdbms(MySQLDialect_mysqldb):
if not _is_dev_environment():
# 'dsn' and 'instance' are because we are skipping
# the traditional google.api.rdbms wrapper
- opts['dsn'] = ''
- opts['instance'] = url.query['instance']
+ opts["dsn"] = ""
+ opts["instance"] = url.query["instance"]
return [], opts
def _extract_error_code(self, exception):
@@ -99,4 +101,5 @@ class MySQLDialect_gaerdbms(MySQLDialect_mysqldb):
if code:
return int(code)
+
dialect = MySQLDialect_gaerdbms
diff --git a/lib/sqlalchemy/dialects/mysql/json.py b/lib/sqlalchemy/dialects/mysql/json.py
index 534fb989d..162d48f73 100644
--- a/lib/sqlalchemy/dialects/mysql/json.py
+++ b/lib/sqlalchemy/dialects/mysql/json.py
@@ -58,7 +58,6 @@ class _FormatTypeMixin(object):
class JSONIndexType(_FormatTypeMixin, sqltypes.JSON.JSONIndexType):
-
def _format_value(self, value):
if isinstance(value, int):
value = "$[%s]" % value
@@ -70,8 +69,10 @@ class JSONIndexType(_FormatTypeMixin, sqltypes.JSON.JSONIndexType):
class JSONPathType(_FormatTypeMixin, sqltypes.JSON.JSONPathType):
def _format_value(self, value):
return "$%s" % (
- "".join([
- "[%s]" % elem if isinstance(elem, int)
- else '."%s"' % elem for elem in value
- ])
+ "".join(
+ [
+ "[%s]" % elem if isinstance(elem, int) else '."%s"' % elem
+ for elem in value
+ ]
+ )
)
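_format_value above renders an index chain as a MySQL JSON path string; a standalone sketch of the same formatting:

    def json_path(elems):  # mirrors JSONPathType._format_value
        return "$" + "".join(
            "[%s]" % e if isinstance(e, int) else '."%s"' % e for e in elems
        )
    assert json_path(["a", 0, "b"]) == '$."a"[0]."b"'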
diff --git a/lib/sqlalchemy/dialects/mysql/mysqlconnector.py b/lib/sqlalchemy/dialects/mysql/mysqlconnector.py
index e16b68bad..9c1502a14 100644
--- a/lib/sqlalchemy/dialects/mysql/mysqlconnector.py
+++ b/lib/sqlalchemy/dialects/mysql/mysqlconnector.py
@@ -47,9 +47,13 @@ are contributed to SQLAlchemy.
"""
-from .base import (MySQLDialect, MySQLExecutionContext,
- MySQLCompiler, MySQLIdentifierPreparer,
- BIT)
+from .base import (
+ MySQLDialect,
+ MySQLExecutionContext,
+ MySQLCompiler,
+ MySQLIdentifierPreparer,
+ BIT,
+)
from ... import util
import re
@@ -57,7 +61,6 @@ from ... import processors
class MySQLExecutionContext_mysqlconnector(MySQLExecutionContext):
-
def get_lastrowid(self):
return self.cursor.lastrowid
@@ -65,21 +68,27 @@ class MySQLExecutionContext_mysqlconnector(MySQLExecutionContext):
class MySQLCompiler_mysqlconnector(MySQLCompiler):
def visit_mod_binary(self, binary, operator, **kw):
if self.dialect._mysqlconnector_double_percents:
- return self.process(binary.left, **kw) + " %% " + \
- self.process(binary.right, **kw)
+ return (
+ self.process(binary.left, **kw)
+ + " %% "
+ + self.process(binary.right, **kw)
+ )
else:
- return self.process(binary.left, **kw) + " % " + \
- self.process(binary.right, **kw)
+ return (
+ self.process(binary.left, **kw)
+ + " % "
+ + self.process(binary.right, **kw)
+ )
def post_process_text(self, text):
if self.dialect._mysqlconnector_double_percents:
- return text.replace('%', '%%')
+ return text.replace("%", "%%")
else:
return text
def escape_literal_column(self, text):
if self.dialect._mysqlconnector_double_percents:
- return text.replace('%', '%%')
+ return text.replace("%", "%%")
else:
return text
@@ -109,7 +118,7 @@ class _myconnpyBIT(BIT):
class MySQLDialect_mysqlconnector(MySQLDialect):
- driver = 'mysqlconnector'
+ driver = "mysqlconnector"
supports_unicode_binds = True
@@ -118,28 +127,22 @@ class MySQLDialect_mysqlconnector(MySQLDialect):
supports_native_decimal = True
- default_paramstyle = 'format'
+ default_paramstyle = "format"
execution_ctx_cls = MySQLExecutionContext_mysqlconnector
statement_compiler = MySQLCompiler_mysqlconnector
preparer = MySQLIdentifierPreparer_mysqlconnector
- colspecs = util.update_copy(
- MySQLDialect.colspecs,
- {
- BIT: _myconnpyBIT,
- }
- )
+ colspecs = util.update_copy(MySQLDialect.colspecs, {BIT: _myconnpyBIT})
def __init__(self, *arg, **kw):
super(MySQLDialect_mysqlconnector, self).__init__(*arg, **kw)
# hack description encoding since mysqlconnector randomly
# returns bytes or not
- self._description_decoder = \
- processors.to_conditional_unicode_processor_factory(
- self.description_encoding
- )
+ self._description_decoder = processors.to_conditional_unicode_processor_factory(
+ self.description_encoding
+ )
def _check_unicode_description(self, connection):
# hack description encoding since mysqlconnector randomly
@@ -158,6 +161,7 @@ class MySQLDialect_mysqlconnector(MySQLDialect):
@classmethod
def dbapi(cls):
from mysql import connector
+
return connector
def do_ping(self, dbapi_connection):
@@ -172,54 +176,52 @@ class MySQLDialect_mysqlconnector(MySQLDialect):
return True
def create_connect_args(self, url):
- opts = url.translate_connect_args(username='user')
+ opts = url.translate_connect_args(username="user")
opts.update(url.query)
- util.coerce_kw_type(opts, 'allow_local_infile', bool)
- util.coerce_kw_type(opts, 'autocommit', bool)
- util.coerce_kw_type(opts, 'buffered', bool)
- util.coerce_kw_type(opts, 'compress', bool)
- util.coerce_kw_type(opts, 'connection_timeout', int)
- util.coerce_kw_type(opts, 'connect_timeout', int)
- util.coerce_kw_type(opts, 'consume_results', bool)
- util.coerce_kw_type(opts, 'force_ipv6', bool)
- util.coerce_kw_type(opts, 'get_warnings', bool)
- util.coerce_kw_type(opts, 'pool_reset_session', bool)
- util.coerce_kw_type(opts, 'pool_size', int)
- util.coerce_kw_type(opts, 'raise_on_warnings', bool)
- util.coerce_kw_type(opts, 'raw', bool)
- util.coerce_kw_type(opts, 'ssl_verify_cert', bool)
- util.coerce_kw_type(opts, 'use_pure', bool)
- util.coerce_kw_type(opts, 'use_unicode', bool)
+ util.coerce_kw_type(opts, "allow_local_infile", bool)
+ util.coerce_kw_type(opts, "autocommit", bool)
+ util.coerce_kw_type(opts, "buffered", bool)
+ util.coerce_kw_type(opts, "compress", bool)
+ util.coerce_kw_type(opts, "connection_timeout", int)
+ util.coerce_kw_type(opts, "connect_timeout", int)
+ util.coerce_kw_type(opts, "consume_results", bool)
+ util.coerce_kw_type(opts, "force_ipv6", bool)
+ util.coerce_kw_type(opts, "get_warnings", bool)
+ util.coerce_kw_type(opts, "pool_reset_session", bool)
+ util.coerce_kw_type(opts, "pool_size", int)
+ util.coerce_kw_type(opts, "raise_on_warnings", bool)
+ util.coerce_kw_type(opts, "raw", bool)
+ util.coerce_kw_type(opts, "ssl_verify_cert", bool)
+ util.coerce_kw_type(opts, "use_pure", bool)
+ util.coerce_kw_type(opts, "use_unicode", bool)
        # unfortunately, MySQL Connector/Python refuses to release a
# cursor without reading fully, so non-buffered isn't an option
- opts.setdefault('buffered', True)
+ opts.setdefault("buffered", True)
# FOUND_ROWS must be set in ClientFlag to enable
# supports_sane_rowcount.
if self.dbapi is not None:
try:
from mysql.connector.constants import ClientFlag
+
client_flags = opts.get(
- 'client_flags', ClientFlag.get_default())
+ "client_flags", ClientFlag.get_default()
+ )
client_flags |= ClientFlag.FOUND_ROWS
- opts['client_flags'] = client_flags
+ opts["client_flags"] = client_flags
except Exception:
pass
return [[], opts]
@util.memoized_property
def _mysqlconnector_version_info(self):
- if self.dbapi and hasattr(self.dbapi, '__version__'):
- m = re.match(r'(\d+)\.(\d+)(?:\.(\d+))?',
- self.dbapi.__version__)
+ if self.dbapi and hasattr(self.dbapi, "__version__"):
+ m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", self.dbapi.__version__)
if m:
- return tuple(
- int(x)
- for x in m.group(1, 2, 3)
- if x is not None)
+ return tuple(int(x) for x in m.group(1, 2, 3) if x is not None)
@util.memoized_property
def _mysqlconnector_double_percents(self):
@@ -235,9 +237,11 @@ class MySQLDialect_mysqlconnector(MySQLDialect):
errnos = (2006, 2013, 2014, 2045, 2055, 2048)
exceptions = (self.dbapi.OperationalError, self.dbapi.InterfaceError)
if isinstance(e, exceptions):
- return e.errno in errnos or \
- "MySQL Connection not available." in str(e) or \
- "Connection to MySQL is not available" in str(e)
+ return (
+ e.errno in errnos
+ or "MySQL Connection not available." in str(e)
+ or "Connection to MySQL is not available" in str(e)
+ )
else:
return False
@@ -247,17 +251,24 @@ class MySQLDialect_mysqlconnector(MySQLDialect):
def _compat_fetchone(self, rp, charset=None):
return rp.fetchone()
- _isolation_lookup = set(['SERIALIZABLE', 'READ UNCOMMITTED',
- 'READ COMMITTED', 'REPEATABLE READ',
- 'AUTOCOMMIT'])
+ _isolation_lookup = set(
+ [
+ "SERIALIZABLE",
+ "READ UNCOMMITTED",
+ "READ COMMITTED",
+ "REPEATABLE READ",
+ "AUTOCOMMIT",
+ ]
+ )
def _set_isolation_level(self, connection, level):
- if level == 'AUTOCOMMIT':
+ if level == "AUTOCOMMIT":
connection.autocommit = True
else:
connection.autocommit = False
super(MySQLDialect_mysqlconnector, self)._set_isolation_level(
- connection, level)
+ connection, level
+ )
dialect = MySQLDialect_mysqlconnector
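Because create_connect_args() coerces these keys, the driver flags can ride on the URL query string; a hedged sketch (host and options are hypothetical):

    from sqlalchemy import create_engine
    engine = create_engine(
        "mysql+mysqlconnector://scott:tiger@localhost/test"
        "?buffered=true&raise_on_warnings=true&connection_timeout=10"
    )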
diff --git a/lib/sqlalchemy/dialects/mysql/mysqldb.py b/lib/sqlalchemy/dialects/mysql/mysqldb.py
index edac816fe..6d42f5c04 100644
--- a/lib/sqlalchemy/dialects/mysql/mysqldb.py
+++ b/lib/sqlalchemy/dialects/mysql/mysqldb.py
@@ -45,8 +45,12 @@ The mysqldb dialect supports server-side cursors. See :ref:`mysql_ss_cursors`.
"""
-from .base import (MySQLDialect, MySQLExecutionContext,
- MySQLCompiler, MySQLIdentifierPreparer)
+from .base import (
+ MySQLDialect,
+ MySQLExecutionContext,
+ MySQLCompiler,
+ MySQLIdentifierPreparer,
+)
from .base import TEXT
from ... import sql
from ... import util
@@ -54,10 +58,9 @@ import re
class MySQLExecutionContext_mysqldb(MySQLExecutionContext):
-
@property
def rowcount(self):
- if hasattr(self, '_rowcount'):
+ if hasattr(self, "_rowcount"):
return self._rowcount
else:
return self.cursor.rowcount
@@ -72,14 +75,14 @@ class MySQLIdentifierPreparer_mysqldb(MySQLIdentifierPreparer):
class MySQLDialect_mysqldb(MySQLDialect):
- driver = 'mysqldb'
+ driver = "mysqldb"
supports_unicode_statements = True
supports_sane_rowcount = True
supports_sane_multi_rowcount = True
supports_native_decimal = True
- default_paramstyle = 'format'
+ default_paramstyle = "format"
execution_ctx_cls = MySQLExecutionContext_mysqldb
statement_compiler = MySQLCompiler_mysqldb
preparer = MySQLIdentifierPreparer_mysqldb
@@ -87,24 +90,23 @@ class MySQLDialect_mysqldb(MySQLDialect):
def __init__(self, server_side_cursors=False, **kwargs):
super(MySQLDialect_mysqldb, self).__init__(**kwargs)
self.server_side_cursors = server_side_cursors
- self._mysql_dbapi_version = self._parse_dbapi_version(
- self.dbapi.__version__) if self.dbapi is not None \
- and hasattr(self.dbapi, '__version__') else (0, 0, 0)
+ self._mysql_dbapi_version = (
+ self._parse_dbapi_version(self.dbapi.__version__)
+ if self.dbapi is not None and hasattr(self.dbapi, "__version__")
+ else (0, 0, 0)
+ )
def _parse_dbapi_version(self, version):
- m = re.match(r'(\d+)\.(\d+)(?:\.(\d+))?', version)
+ m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
if m:
- return tuple(
- int(x)
- for x in m.group(1, 2, 3)
- if x is not None)
+ return tuple(int(x) for x in m.group(1, 2, 3) if x is not None)
else:
return (0, 0, 0)
@util.langhelpers.memoized_property
def supports_server_side_cursors(self):
try:
- cursors = __import__('MySQLdb.cursors').cursors
+ cursors = __import__("MySQLdb.cursors").cursors
self._sscursor = cursors.SSCursor
return True
except (ImportError, AttributeError):
@@ -112,7 +114,7 @@ class MySQLDialect_mysqldb(MySQLDialect):
@classmethod
def dbapi(cls):
- return __import__('MySQLdb')
+ return __import__("MySQLdb")
def do_ping(self, dbapi_connection):
try:
@@ -135,67 +137,74 @@ class MySQLDialect_mysqldb(MySQLDialect):
# https://github.com/farcepest/MySQLdb1/commit/cd44524fef63bd3fcb71947392326e9742d520e8
# specific issue w/ the utf8mb4_bin collation and unicode returns
- has_utf8mb4_bin = self.server_version_info > (5, ) and \
- connection.scalar(
- "show collation where %s = 'utf8mb4' and %s = 'utf8mb4_bin'"
- % (
- self.identifier_preparer.quote("Charset"),
- self.identifier_preparer.quote("Collation")
- ))
+ has_utf8mb4_bin = self.server_version_info > (
+ 5,
+ ) and connection.scalar(
+ "show collation where %s = 'utf8mb4' and %s = 'utf8mb4_bin'"
+ % (
+ self.identifier_preparer.quote("Charset"),
+ self.identifier_preparer.quote("Collation"),
+ )
+ )
if has_utf8mb4_bin:
additional_tests = [
- sql.collate(sql.cast(
- sql.literal_column(
- "'test collated returns'"),
- TEXT(charset='utf8mb4')), "utf8mb4_bin")
+ sql.collate(
+ sql.cast(
+ sql.literal_column("'test collated returns'"),
+ TEXT(charset="utf8mb4"),
+ ),
+ "utf8mb4_bin",
+ )
]
else:
additional_tests = []
return super(MySQLDialect_mysqldb, self)._check_unicode_returns(
- connection, additional_tests)
+ connection, additional_tests
+ )
def create_connect_args(self, url):
- opts = url.translate_connect_args(database='db', username='user',
- password='passwd')
+ opts = url.translate_connect_args(
+ database="db", username="user", password="passwd"
+ )
opts.update(url.query)
- util.coerce_kw_type(opts, 'compress', bool)
- util.coerce_kw_type(opts, 'connect_timeout', int)
- util.coerce_kw_type(opts, 'read_timeout', int)
- util.coerce_kw_type(opts, 'write_timeout', int)
- util.coerce_kw_type(opts, 'client_flag', int)
- util.coerce_kw_type(opts, 'local_infile', int)
+ util.coerce_kw_type(opts, "compress", bool)
+ util.coerce_kw_type(opts, "connect_timeout", int)
+ util.coerce_kw_type(opts, "read_timeout", int)
+ util.coerce_kw_type(opts, "write_timeout", int)
+ util.coerce_kw_type(opts, "client_flag", int)
+ util.coerce_kw_type(opts, "local_infile", int)
# Note: using either of the below will cause all strings to be
# returned as Unicode, both in raw SQL operations and with column
# types like String and MSString.
- util.coerce_kw_type(opts, 'use_unicode', bool)
- util.coerce_kw_type(opts, 'charset', str)
+ util.coerce_kw_type(opts, "use_unicode", bool)
+ util.coerce_kw_type(opts, "charset", str)
# Rich values 'cursorclass' and 'conv' are not supported via
# query string.
ssl = {}
- keys = ['ssl_ca', 'ssl_key', 'ssl_cert', 'ssl_capath', 'ssl_cipher']
+ keys = ["ssl_ca", "ssl_key", "ssl_cert", "ssl_capath", "ssl_cipher"]
for key in keys:
if key in opts:
ssl[key[4:]] = opts[key]
util.coerce_kw_type(ssl, key[4:], str)
del opts[key]
if ssl:
- opts['ssl'] = ssl
+ opts["ssl"] = ssl
# FOUND_ROWS must be set in CLIENT_FLAGS to enable
# supports_sane_rowcount.
- client_flag = opts.get('client_flag', 0)
+ client_flag = opts.get("client_flag", 0)
if self.dbapi is not None:
try:
CLIENT_FLAGS = __import__(
- self.dbapi.__name__ + '.constants.CLIENT'
+ self.dbapi.__name__ + ".constants.CLIENT"
).constants.CLIENT
client_flag |= CLIENT_FLAGS.FOUND_ROWS
except (AttributeError, ImportError):
self.supports_sane_rowcount = False
- opts['client_flag'] = client_flag
+ opts["client_flag"] = client_flag
return [[], opts]
def _extract_error_code(self, exception):
@@ -213,22 +222,30 @@ class MySQLDialect_mysqldb(MySQLDialect):
"No 'character_set_name' can be detected with "
"this MySQL-Python version; "
"please upgrade to a recent version of MySQL-Python. "
- "Assuming latin1.")
- return 'latin1'
+ "Assuming latin1."
+ )
+ return "latin1"
else:
return cset_name()
- _isolation_lookup = set(['SERIALIZABLE', 'READ UNCOMMITTED',
- 'READ COMMITTED', 'REPEATABLE READ',
- 'AUTOCOMMIT'])
+ _isolation_lookup = set(
+ [
+ "SERIALIZABLE",
+ "READ UNCOMMITTED",
+ "READ COMMITTED",
+ "REPEATABLE READ",
+ "AUTOCOMMIT",
+ ]
+ )
def _set_isolation_level(self, connection, level):
- if level == 'AUTOCOMMIT':
+ if level == "AUTOCOMMIT":
connection.autocommit(True)
else:
connection.autocommit(False)
- super(MySQLDialect_mysqldb, self)._set_isolation_level(connection,
- level)
+ super(MySQLDialect_mysqldb, self)._set_isolation_level(
+ connection, level
+ )
dialect = MySQLDialect_mysqldb
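The ssl_-prefixed URL keys above are folded into a single "ssl" dict with the prefix stripped (ssl_ca -> ca) before reaching MySQLdb.connect(); a hedged sketch (paths and values are hypothetical):

    from sqlalchemy import create_engine
    engine = create_engine(
        "mysql+mysqldb://scott:tiger@localhost/test"
        "?ssl_ca=/path/to/ca.pem&charset=utf8mb4"
    )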
diff --git a/lib/sqlalchemy/dialects/mysql/oursql.py b/lib/sqlalchemy/dialects/mysql/oursql.py
index 67dbb7cf2..8ba353a31 100644
--- a/lib/sqlalchemy/dialects/mysql/oursql.py
+++ b/lib/sqlalchemy/dialects/mysql/oursql.py
@@ -24,7 +24,7 @@ handling.
import re
-from .base import (BIT, MySQLDialect, MySQLExecutionContext)
+from .base import BIT, MySQLDialect, MySQLExecutionContext
from ... import types as sqltypes, util
@@ -36,14 +36,13 @@ class _oursqlBIT(BIT):
class MySQLExecutionContext_oursql(MySQLExecutionContext):
-
@property
def plain_query(self):
- return self.execution_options.get('_oursql_plain_query', False)
+ return self.execution_options.get("_oursql_plain_query", False)
class MySQLDialect_oursql(MySQLDialect):
- driver = 'oursql'
+ driver = "oursql"
if util.py2k:
supports_unicode_binds = True
@@ -56,16 +55,12 @@ class MySQLDialect_oursql(MySQLDialect):
execution_ctx_cls = MySQLExecutionContext_oursql
colspecs = util.update_copy(
- MySQLDialect.colspecs,
- {
- sqltypes.Time: sqltypes.Time,
- BIT: _oursqlBIT,
- }
+ MySQLDialect.colspecs, {sqltypes.Time: sqltypes.Time, BIT: _oursqlBIT}
)
@classmethod
def dbapi(cls):
- return __import__('oursql')
+ return __import__("oursql")
def do_execute(self, cursor, statement, parameters, context=None):
"""Provide an implementation of
@@ -77,7 +72,7 @@ class MySQLDialect_oursql(MySQLDialect):
cursor.execute(statement, parameters)
def do_begin(self, connection):
- connection.cursor().execute('BEGIN', plain_query=True)
+ connection.cursor().execute("BEGIN", plain_query=True)
def _xa_query(self, connection, query, xid):
if util.py2k:
@@ -85,10 +80,12 @@ class MySQLDialect_oursql(MySQLDialect):
else:
charset = self._connection_charset
arg = connection.connection._escape_string(
- xid.encode(charset)).decode(charset)
+ xid.encode(charset)
+ ).decode(charset)
arg = "'%s'" % arg
- connection.execution_options(
- _oursql_plain_query=True).execute(query % arg)
+ connection.execution_options(_oursql_plain_query=True).execute(
+ query % arg
+ )
# Because mysql is bad, these methods have to be
# reimplemented to use _PlainQuery. Basically, some queries
@@ -96,23 +93,25 @@ class MySQLDialect_oursql(MySQLDialect):
# the parameterized query API, or refuse to be parameterized
# in the first place.
def do_begin_twophase(self, connection, xid):
- self._xa_query(connection, 'XA BEGIN %s', xid)
+ self._xa_query(connection, "XA BEGIN %s", xid)
def do_prepare_twophase(self, connection, xid):
- self._xa_query(connection, 'XA END %s', xid)
- self._xa_query(connection, 'XA PREPARE %s', xid)
+ self._xa_query(connection, "XA END %s", xid)
+ self._xa_query(connection, "XA PREPARE %s", xid)
- def do_rollback_twophase(self, connection, xid, is_prepared=True,
- recover=False):
+ def do_rollback_twophase(
+ self, connection, xid, is_prepared=True, recover=False
+ ):
if not is_prepared:
- self._xa_query(connection, 'XA END %s', xid)
- self._xa_query(connection, 'XA ROLLBACK %s', xid)
+ self._xa_query(connection, "XA END %s", xid)
+ self._xa_query(connection, "XA ROLLBACK %s", xid)
- def do_commit_twophase(self, connection, xid, is_prepared=True,
- recover=False):
+ def do_commit_twophase(
+ self, connection, xid, is_prepared=True, recover=False
+ ):
if not is_prepared:
self.do_prepare_twophase(connection, xid)
- self._xa_query(connection, 'XA COMMIT %s', xid)
+ self._xa_query(connection, "XA COMMIT %s", xid)
    # Q: why didn't we need all these "plain_query" overrides earlier?
    # Am I on a newer/older version of OurSQL?
@@ -121,7 +120,7 @@ class MySQLDialect_oursql(MySQLDialect):
self,
connection.connect().execution_options(_oursql_plain_query=True),
table_name,
- schema
+ schema,
)
def get_table_options(self, connection, table_name, schema=None, **kw):
@@ -154,7 +153,7 @@ class MySQLDialect_oursql(MySQLDialect):
return MySQLDialect.get_table_names(
self,
connection.connect().execution_options(_oursql_plain_query=True),
- schema
+ schema,
)
def get_schema_names(self, connection, **kw):
@@ -166,57 +165,69 @@ class MySQLDialect_oursql(MySQLDialect):
def initialize(self, connection):
return MySQLDialect.initialize(
- self,
- connection.execution_options(_oursql_plain_query=True)
+ self, connection.execution_options(_oursql_plain_query=True)
)
- def _show_create_table(self, connection, table, charset=None,
- full_name=None):
+ def _show_create_table(
+ self, connection, table, charset=None, full_name=None
+ ):
return MySQLDialect._show_create_table(
self,
- connection.contextual_connect(close_with_result=True).
- execution_options(_oursql_plain_query=True),
- table, charset, full_name
+ connection.contextual_connect(
+ close_with_result=True
+ ).execution_options(_oursql_plain_query=True),
+ table,
+ charset,
+ full_name,
)
def is_disconnect(self, e, connection, cursor):
if isinstance(e, self.dbapi.ProgrammingError):
- return e.errno is None and 'cursor' not in e.args[1] \
- and e.args[1].endswith('closed')
+ return (
+ e.errno is None
+ and "cursor" not in e.args[1]
+ and e.args[1].endswith("closed")
+ )
else:
return e.errno in (2006, 2013, 2014, 2045, 2055)
def create_connect_args(self, url):
- opts = url.translate_connect_args(database='db', username='user',
- password='passwd')
+ opts = url.translate_connect_args(
+ database="db", username="user", password="passwd"
+ )
opts.update(url.query)
- util.coerce_kw_type(opts, 'port', int)
- util.coerce_kw_type(opts, 'compress', bool)
- util.coerce_kw_type(opts, 'autoping', bool)
- util.coerce_kw_type(opts, 'raise_on_warnings', bool)
+ util.coerce_kw_type(opts, "port", int)
+ util.coerce_kw_type(opts, "compress", bool)
+ util.coerce_kw_type(opts, "autoping", bool)
+ util.coerce_kw_type(opts, "raise_on_warnings", bool)
- util.coerce_kw_type(opts, 'default_charset', bool)
- if opts.pop('default_charset', False):
- opts['charset'] = None
+ util.coerce_kw_type(opts, "default_charset", bool)
+ if opts.pop("default_charset", False):
+ opts["charset"] = None
else:
- util.coerce_kw_type(opts, 'charset', str)
- opts['use_unicode'] = opts.get('use_unicode', True)
- util.coerce_kw_type(opts, 'use_unicode', bool)
+ util.coerce_kw_type(opts, "charset", str)
+ opts["use_unicode"] = opts.get("use_unicode", True)
+ util.coerce_kw_type(opts, "use_unicode", bool)
# FOUND_ROWS must be set in CLIENT_FLAGS to enable
# supports_sane_rowcount.
- opts.setdefault('found_rows', True)
+ opts.setdefault("found_rows", True)
ssl = {}
- for key in ['ssl_ca', 'ssl_key', 'ssl_cert',
- 'ssl_capath', 'ssl_cipher']:
+ for key in [
+ "ssl_ca",
+ "ssl_key",
+ "ssl_cert",
+ "ssl_capath",
+ "ssl_cipher",
+ ]:
if key in opts:
ssl[key[4:]] = opts[key]
util.coerce_kw_type(ssl, key[4:], str)
del opts[key]
if ssl:
- opts['ssl'] = ssl
+ opts["ssl"] = ssl
return [[], opts]
diff --git a/lib/sqlalchemy/dialects/mysql/pymysql.py b/lib/sqlalchemy/dialects/mysql/pymysql.py
index 5f176cef2..94dbfff06 100644
--- a/lib/sqlalchemy/dialects/mysql/pymysql.py
+++ b/lib/sqlalchemy/dialects/mysql/pymysql.py
@@ -34,7 +34,7 @@ from ...util import langhelpers, py3k
class MySQLDialect_pymysql(MySQLDialect_mysqldb):
- driver = 'pymysql'
+ driver = "pymysql"
description_encoding = None
@@ -51,7 +51,7 @@ class MySQLDialect_pymysql(MySQLDialect_mysqldb):
@langhelpers.memoized_property
def supports_server_side_cursors(self):
try:
- cursors = __import__('pymysql.cursors').cursors
+ cursors = __import__("pymysql.cursors").cursors
self._sscursor = cursors.SSCursor
return True
except (ImportError, AttributeError):
@@ -59,10 +59,12 @@ class MySQLDialect_pymysql(MySQLDialect_mysqldb):
@classmethod
def dbapi(cls):
- return __import__('pymysql')
+ return __import__("pymysql")
def is_disconnect(self, e, connection, cursor):
- if super(MySQLDialect_pymysql, self).is_disconnect(e, connection, cursor):
+ if super(MySQLDialect_pymysql, self).is_disconnect(
+ e, connection, cursor
+ ):
return True
elif isinstance(e, self.dbapi.Error):
return "Already closed" in str(e)
@@ -70,9 +72,11 @@ class MySQLDialect_pymysql(MySQLDialect_mysqldb):
return False
if py3k:
+
def _extract_error_code(self, exception):
if isinstance(exception.args[0], Exception):
exception = exception.args[0]
return exception.args[0]
+
dialect = MySQLDialect_pymysql
diff --git a/lib/sqlalchemy/dialects/mysql/pyodbc.py b/lib/sqlalchemy/dialects/mysql/pyodbc.py
index 718754651..91512857e 100644
--- a/lib/sqlalchemy/dialects/mysql/pyodbc.py
+++ b/lib/sqlalchemy/dialects/mysql/pyodbc.py
@@ -29,7 +29,6 @@ import re
class MySQLExecutionContext_pyodbc(MySQLExecutionContext):
-
def get_lastrowid(self):
cursor = self.create_cursor()
cursor.execute("SELECT LAST_INSERT_ID()")
@@ -46,7 +45,7 @@ class MySQLDialect_pyodbc(PyODBCConnector, MySQLDialect):
def __init__(self, **kw):
# deal with http://code.google.com/p/pyodbc/issues/detail?id=25
- kw.setdefault('convert_unicode', True)
+ kw.setdefault("convert_unicode", True)
super(MySQLDialect_pyodbc, self).__init__(**kw)
def _detect_charset(self, connection):
@@ -60,13 +59,15 @@ class MySQLDialect_pyodbc(PyODBCConnector, MySQLDialect):
# this can prefer the driver value.
rs = connection.execute("SHOW VARIABLES LIKE 'character_set%%'")
opts = {row[0]: row[1] for row in self._compat_fetchall(rs)}
- for key in ('character_set_connection', 'character_set'):
+ for key in ("character_set_connection", "character_set"):
if opts.get(key, None):
return opts[key]
- util.warn("Could not detect the connection character set. "
- "Assuming latin1.")
- return 'latin1'
+ util.warn(
+ "Could not detect the connection character set. "
+ "Assuming latin1."
+ )
+ return "latin1"
def _extract_error_code(self, exception):
m = re.compile(r"\((\d+)\)").search(str(exception.args))
@@ -76,4 +77,5 @@ class MySQLDialect_pyodbc(PyODBCConnector, MySQLDialect):
else:
return None
+
dialect = MySQLDialect_pyodbc
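_extract_error_code() above regex-scans the stringified ODBC exception args for the parenthesized MySQL error number; a sketch with a made-up pyodbc-style message:

    import re
    args = ("HY000", "[MySQL][ODBC Driver] (2006) MySQL server has gone away")
    m = re.compile(r"\((\d+)\)").search(str(args))
    assert m and int(m.group(1)) == 2006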
diff --git a/lib/sqlalchemy/dialects/mysql/reflection.py b/lib/sqlalchemy/dialects/mysql/reflection.py
index e88bc3f42..d0513eb4d 100644
--- a/lib/sqlalchemy/dialects/mysql/reflection.py
+++ b/lib/sqlalchemy/dialects/mysql/reflection.py
@@ -36,16 +36,16 @@ class MySQLTableDefinitionParser(object):
def parse(self, show_create, charset):
state = ReflectedState()
state.charset = charset
- for line in re.split(r'\r?\n', show_create):
- if line.startswith(' ' + self.preparer.initial_quote):
+ for line in re.split(r"\r?\n", show_create):
+ if line.startswith(" " + self.preparer.initial_quote):
self._parse_column(line, state)
# a regular table options line
- elif line.startswith(') '):
+ elif line.startswith(") "):
self._parse_table_options(line, state)
# an ANSI-mode table options line
- elif line == ')':
+ elif line == ")":
pass
- elif line.startswith('CREATE '):
+ elif line.startswith("CREATE "):
self._parse_table_name(line, state)
# Not present in real reflection, but may be if
# loading from a file.
@@ -55,11 +55,11 @@ class MySQLTableDefinitionParser(object):
type_, spec = self._parse_constraints(line)
if type_ is None:
util.warn("Unknown schema content: %r" % line)
- elif type_ == 'key':
+ elif type_ == "key":
state.keys.append(spec)
- elif type_ == 'fk_constraint':
+ elif type_ == "fk_constraint":
state.fk_constraints.append(spec)
- elif type_ == 'ck_constraint':
+ elif type_ == "ck_constraint":
state.ck_constraints.append(spec)
else:
pass
@@ -78,39 +78,39 @@ class MySQLTableDefinitionParser(object):
# convert columns into name, length pairs
# NOTE: we may want to consider SHOW INDEX as the
# format of indexes in MySQL becomes more complex
- spec['columns'] = self._parse_keyexprs(spec['columns'])
- if spec['version_sql']:
- m2 = self._re_key_version_sql.match(spec['version_sql'])
- if m2 and m2.groupdict()['parser']:
- spec['parser'] = m2.groupdict()['parser']
- if spec['parser']:
- spec['parser'] = self.preparer.unformat_identifiers(
- spec['parser'])[0]
- return 'key', spec
+ spec["columns"] = self._parse_keyexprs(spec["columns"])
+ if spec["version_sql"]:
+ m2 = self._re_key_version_sql.match(spec["version_sql"])
+ if m2 and m2.groupdict()["parser"]:
+ spec["parser"] = m2.groupdict()["parser"]
+ if spec["parser"]:
+ spec["parser"] = self.preparer.unformat_identifiers(
+ spec["parser"]
+ )[0]
+ return "key", spec
# FOREIGN KEY CONSTRAINT
m = self._re_fk_constraint.match(line)
if m:
spec = m.groupdict()
- spec['table'] = \
- self.preparer.unformat_identifiers(spec['table'])
- spec['local'] = [c[0]
- for c in self._parse_keyexprs(spec['local'])]
- spec['foreign'] = [c[0]
- for c in self._parse_keyexprs(spec['foreign'])]
- return 'fk_constraint', spec
+ spec["table"] = self.preparer.unformat_identifiers(spec["table"])
+ spec["local"] = [c[0] for c in self._parse_keyexprs(spec["local"])]
+ spec["foreign"] = [
+ c[0] for c in self._parse_keyexprs(spec["foreign"])
+ ]
+ return "fk_constraint", spec
# CHECK constraint
m = self._re_ck_constraint.match(line)
if m:
spec = m.groupdict()
- return 'ck_constraint', spec
+ return "ck_constraint", spec
# PARTITION and SUBPARTITION
m = self._re_partition.match(line)
if m:
# Punt!
- return 'partition', line
+ return "partition", line
# No match.
return (None, line)
@@ -124,7 +124,7 @@ class MySQLTableDefinitionParser(object):
regex, cleanup = self._pr_name
m = regex.match(line)
if m:
- state.table_name = cleanup(m.group('name'))
+ state.table_name = cleanup(m.group("name"))
def _parse_table_options(self, line, state):
"""Build a dictionary of all reflected table-level options.
@@ -134,7 +134,7 @@ class MySQLTableDefinitionParser(object):
options = {}
- if not line or line == ')':
+ if not line or line == ")":
pass
else:
@@ -143,17 +143,17 @@ class MySQLTableDefinitionParser(object):
m = regex.search(rest_of_line)
if not m:
continue
- directive, value = m.group('directive'), m.group('val')
+ directive, value = m.group("directive"), m.group("val")
if cleanup:
value = cleanup(value)
options[directive.lower()] = value
- rest_of_line = regex.sub('', rest_of_line)
+ rest_of_line = regex.sub("", rest_of_line)
- for nope in ('auto_increment', 'data directory', 'index directory'):
+ for nope in ("auto_increment", "data directory", "index directory"):
options.pop(nope, None)
for opt, val in options.items():
- state.table_options['%s_%s' % (self.dialect.name, opt)] = val
+ state.table_options["%s_%s" % (self.dialect.name, opt)] = val
def _parse_column(self, line, state):
"""Extract column details.
@@ -167,29 +167,30 @@ class MySQLTableDefinitionParser(object):
m = self._re_column.match(line)
if m:
spec = m.groupdict()
- spec['full'] = True
+ spec["full"] = True
else:
m = self._re_column_loose.match(line)
if m:
spec = m.groupdict()
- spec['full'] = False
+ spec["full"] = False
if not spec:
util.warn("Unknown column definition %r" % line)
return
- if not spec['full']:
+ if not spec["full"]:
util.warn("Incomplete reflection of column definition %r" % line)
- name, type_, args = spec['name'], spec['coltype'], spec['arg']
+ name, type_, args = spec["name"], spec["coltype"], spec["arg"]
try:
col_type = self.dialect.ischema_names[type_]
except KeyError:
- util.warn("Did not recognize type '%s' of column '%s'" %
- (type_, name))
+ util.warn(
+ "Did not recognize type '%s' of column '%s'" % (type_, name)
+ )
col_type = sqltypes.NullType
# Column type positional arguments eg. varchar(32)
- if args is None or args == '':
+ if args is None or args == "":
type_args = []
elif args[0] == "'" and args[-1] == "'":
type_args = self._re_csv_str.findall(args)
@@ -201,50 +202,51 @@ class MySQLTableDefinitionParser(object):
if issubclass(col_type, (DATETIME, TIME, TIMESTAMP)):
if type_args:
- type_kw['fsp'] = type_args.pop(0)
+ type_kw["fsp"] = type_args.pop(0)
- for kw in ('unsigned', 'zerofill'):
+ for kw in ("unsigned", "zerofill"):
if spec.get(kw, False):
type_kw[kw] = True
- for kw in ('charset', 'collate'):
+ for kw in ("charset", "collate"):
if spec.get(kw, False):
type_kw[kw] = spec[kw]
if issubclass(col_type, _EnumeratedValues):
type_args = _EnumeratedValues._strip_values(type_args)
- if issubclass(col_type, SET) and '' in type_args:
- type_kw['retrieve_as_bitwise'] = True
+ if issubclass(col_type, SET) and "" in type_args:
+ type_kw["retrieve_as_bitwise"] = True
type_instance = col_type(*type_args, **type_kw)
col_kw = {}
# NOT NULL
- col_kw['nullable'] = True
+ col_kw["nullable"] = True
# this can be "NULL" in the case of TIMESTAMP
- if spec.get('notnull', False) == 'NOT NULL':
- col_kw['nullable'] = False
+ if spec.get("notnull", False) == "NOT NULL":
+ col_kw["nullable"] = False
# AUTO_INCREMENT
- if spec.get('autoincr', False):
- col_kw['autoincrement'] = True
+ if spec.get("autoincr", False):
+ col_kw["autoincrement"] = True
elif issubclass(col_type, sqltypes.Integer):
- col_kw['autoincrement'] = False
+ col_kw["autoincrement"] = False
# DEFAULT
- default = spec.get('default', None)
+ default = spec.get("default", None)
- if default == 'NULL':
+ if default == "NULL":
# eliminates the need to deal with this later.
default = None
- comment = spec.get('comment', None)
+ comment = spec.get("comment", None)
if comment is not None:
comment = comment.replace("\\\\", "\\").replace("''", "'")
- col_d = dict(name=name, type=type_instance, default=default,
- comment=comment)
+ col_d = dict(
+ name=name, type=type_instance, default=default, comment=comment
+ )
col_d.update(col_kw)
state.columns.append(col_d)
@@ -262,36 +264,44 @@ class MySQLTableDefinitionParser(object):
buffer = []
for row in columns:
- (name, col_type, nullable, default, extra) = \
- [row[i] for i in (0, 1, 2, 4, 5)]
+ (name, col_type, nullable, default, extra) = [
+ row[i] for i in (0, 1, 2, 4, 5)
+ ]
- line = [' ']
+ line = [" "]
line.append(self.preparer.quote_identifier(name))
line.append(col_type)
if not nullable:
- line.append('NOT NULL')
+ line.append("NOT NULL")
if default:
- if 'auto_increment' in default:
+ if "auto_increment" in default:
pass
- elif (col_type.startswith('timestamp') and
- default.startswith('C')):
- line.append('DEFAULT')
+ elif col_type.startswith("timestamp") and default.startswith(
+ "C"
+ ):
+ line.append("DEFAULT")
line.append(default)
- elif default == 'NULL':
- line.append('DEFAULT')
+ elif default == "NULL":
+ line.append("DEFAULT")
line.append(default)
else:
- line.append('DEFAULT')
+ line.append("DEFAULT")
line.append("'%s'" % default.replace("'", "''"))
if extra:
line.append(extra)
- buffer.append(' '.join(line))
-
- return ''.join([('CREATE TABLE %s (\n' %
- self.preparer.quote_identifier(table_name)),
- ',\n'.join(buffer),
- '\n) '])
+ buffer.append(" ".join(line))
+
+ return "".join(
+ [
+ (
+ "CREATE TABLE %s (\n"
+ % self.preparer.quote_identifier(table_name)
+ ),
+ ",\n".join(buffer),
+ "\n) ",
+ ]
+ )
def _parse_keyexprs(self, identifiers):
"""Unpack '"col"(2),"col" ASC'-ish strings into components."""
@@ -306,29 +316,39 @@ class MySQLTableDefinitionParser(object):
_final = self.preparer.final_quote
- quotes = dict(zip(('iq', 'fq', 'esc_fq'),
- [re.escape(s) for s in
- (self.preparer.initial_quote,
- _final,
- self.preparer._escape_identifier(_final))]))
+ quotes = dict(
+ zip(
+ ("iq", "fq", "esc_fq"),
+ [
+ re.escape(s)
+ for s in (
+ self.preparer.initial_quote,
+ _final,
+ self.preparer._escape_identifier(_final),
+ )
+ ],
+ )
+ )
self._pr_name = _pr_compile(
- r'^CREATE (?:\w+ +)?TABLE +'
- r'%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +\($' % quotes,
- self.preparer._unescape_identifier)
+ r"^CREATE (?:\w+ +)?TABLE +"
+ r"%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +\($" % quotes,
+ self.preparer._unescape_identifier,
+ )
# `col`,`col2`(32),`col3`(15) DESC
#
self._re_keyexprs = _re_compile(
- r'(?:'
- r'(?:%(iq)s((?:%(esc_fq)s|[^%(fq)s])+)%(fq)s)'
- r'(?:\((\d+)\))?(?: +(ASC|DESC))?(?=\,|$))+' % quotes)
+ r"(?:"
+ r"(?:%(iq)s((?:%(esc_fq)s|[^%(fq)s])+)%(fq)s)"
+ r"(?:\((\d+)\))?(?: +(ASC|DESC))?(?=\,|$))+" % quotes
+ )
# 'foo' or 'foo','bar' or 'fo,o','ba''a''r'
- self._re_csv_str = _re_compile(r'\x27(?:\x27\x27|[^\x27])*\x27')
+ self._re_csv_str = _re_compile(r"\x27(?:\x27\x27|[^\x27])*\x27")
# 123 or 123,456
- self._re_csv_int = _re_compile(r'\d+')
+ self._re_csv_int = _re_compile(r"\d+")
# `colname` <type> [type opts]
# (NOT NULL | NULL)
@@ -356,43 +376,39 @@ class MySQLTableDefinitionParser(object):
r"(?: +COLUMN_FORMAT +(?P<colfmt>\w+))?"
r"(?: +STORAGE +(?P<storage>\w+))?"
r"(?: +(?P<extra>.*))?"
- r",?$"
- % quotes
+ r",?$" % quotes
)
# Fallback, try to parse as little as possible
self._re_column_loose = _re_compile(
- r' '
- r'%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +'
- r'(?P<coltype>\w+)'
- r'(?:\((?P<arg>(?:\d+|\d+,\d+|\x27(?:\x27\x27|[^\x27])+\x27))\))?'
- r'.*?(?P<notnull>(?:NOT )NULL)?'
- % quotes
+ r" "
+ r"%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +"
+ r"(?P<coltype>\w+)"
+ r"(?:\((?P<arg>(?:\d+|\d+,\d+|\x27(?:\x27\x27|[^\x27])+\x27))\))?"
+ r".*?(?P<notnull>(?:NOT )NULL)?" % quotes
)
# (PRIMARY|UNIQUE|FULLTEXT|SPATIAL) INDEX `name` (USING (BTREE|HASH))?
# (`col` (ASC|DESC)?, `col` (ASC|DESC)?)
# KEY_BLOCK_SIZE size | WITH PARSER name /*!50100 WITH PARSER name */
self._re_key = _re_compile(
- r' '
- r'(?:(?P<type>\S+) )?KEY'
- r'(?: +%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s)?'
- r'(?: +USING +(?P<using_pre>\S+))?'
- r' +\((?P<columns>.+?)\)'
- r'(?: +USING +(?P<using_post>\S+))?'
- r'(?: +KEY_BLOCK_SIZE *[ =]? *(?P<keyblock>\S+))?'
- r'(?: +WITH PARSER +(?P<parser>\S+))?'
- r'(?: +COMMENT +(?P<comment>(\x27\x27|\x27([^\x27])*?\x27)+))?'
- r'(?: +/\*(?P<version_sql>.+)\*/ +)?'
- r',?$'
- % quotes
+ r" "
+ r"(?:(?P<type>\S+) )?KEY"
+ r"(?: +%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s)?"
+ r"(?: +USING +(?P<using_pre>\S+))?"
+ r" +\((?P<columns>.+?)\)"
+ r"(?: +USING +(?P<using_post>\S+))?"
+ r"(?: +KEY_BLOCK_SIZE *[ =]? *(?P<keyblock>\S+))?"
+ r"(?: +WITH PARSER +(?P<parser>\S+))?"
+ r"(?: +COMMENT +(?P<comment>(\x27\x27|\x27([^\x27])*?\x27)+))?"
+ r"(?: +/\*(?P<version_sql>.+)\*/ +)?"
+ r",?$" % quotes
)
# https://forums.mysql.com/read.php?20,567102,567111#msg-567111
# It means: if the MySQL version is >= \d+, execute what's in the comment
self._re_key_version_sql = _re_compile(
- r'\!\d+ '
- r'(?: *WITH PARSER +(?P<parser>\S+) *)?'
+ r"\!\d+ " r"(?: *WITH PARSER +(?P<parser>\S+) *)?"
)
# CONSTRAINT `name` FOREIGN KEY (`local_col`)
@@ -402,20 +418,19 @@ class MySQLTableDefinitionParser(object):
#
# unique constraints come back as KEYs
kw = quotes.copy()
- kw['on'] = 'RESTRICT|CASCADE|SET NULL|NOACTION'
+ kw["on"] = "RESTRICT|CASCADE|SET NULL|NOACTION"
self._re_fk_constraint = _re_compile(
- r' '
- r'CONSTRAINT +'
- r'%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +'
- r'FOREIGN KEY +'
- r'\((?P<local>[^\)]+?)\) REFERENCES +'
- r'(?P<table>%(iq)s[^%(fq)s]+%(fq)s'
- r'(?:\.%(iq)s[^%(fq)s]+%(fq)s)?) +'
- r'\((?P<foreign>[^\)]+?)\)'
- r'(?: +(?P<match>MATCH \w+))?'
- r'(?: +ON DELETE (?P<ondelete>%(on)s))?'
- r'(?: +ON UPDATE (?P<onupdate>%(on)s))?'
- % kw
+ r" "
+ r"CONSTRAINT +"
+ r"%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +"
+ r"FOREIGN KEY +"
+ r"\((?P<local>[^\)]+?)\) REFERENCES +"
+ r"(?P<table>%(iq)s[^%(fq)s]+%(fq)s"
+ r"(?:\.%(iq)s[^%(fq)s]+%(fq)s)?) +"
+ r"\((?P<foreign>[^\)]+?)\)"
+ r"(?: +(?P<match>MATCH \w+))?"
+ r"(?: +ON DELETE (?P<ondelete>%(on)s))?"
+ r"(?: +ON UPDATE (?P<onupdate>%(on)s))?" % kw
)
# CONSTRAINT `CONSTRAINT_1` CHECK (`x` > 5)
@@ -423,18 +438,17 @@ class MySQLTableDefinitionParser(object):
# is returned on a line by itself, so to match without worrying
# about parentheses in the expression, we go to the end of the line
self._re_ck_constraint = _re_compile(
- r' '
- r'CONSTRAINT +'
- r'%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +'
- r'CHECK +'
- r'\((?P<sqltext>.+)\),?'
- % kw
+ r" "
+ r"CONSTRAINT +"
+ r"%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +"
+ r"CHECK +"
+ r"\((?P<sqltext>.+)\),?" % kw
)
# PARTITION
#
# punt!
- self._re_partition = _re_compile(r'(?:.*)(?:SUB)?PARTITION(?:.*)')
+ self._re_partition = _re_compile(r"(?:.*)(?:SUB)?PARTITION(?:.*)")
# Table-level options (COLLATE, ENGINE, etc.)
# Do the string options first, since they have quoted
@@ -442,44 +456,68 @@ class MySQLTableDefinitionParser(object):
for option in _options_of_type_string:
self._add_option_string(option)
- for option in ('ENGINE', 'TYPE', 'AUTO_INCREMENT',
- 'AVG_ROW_LENGTH', 'CHARACTER SET',
- 'DEFAULT CHARSET', 'CHECKSUM',
- 'COLLATE', 'DELAY_KEY_WRITE', 'INSERT_METHOD',
- 'MAX_ROWS', 'MIN_ROWS', 'PACK_KEYS', 'ROW_FORMAT',
- 'KEY_BLOCK_SIZE'):
+ for option in (
+ "ENGINE",
+ "TYPE",
+ "AUTO_INCREMENT",
+ "AVG_ROW_LENGTH",
+ "CHARACTER SET",
+ "DEFAULT CHARSET",
+ "CHECKSUM",
+ "COLLATE",
+ "DELAY_KEY_WRITE",
+ "INSERT_METHOD",
+ "MAX_ROWS",
+ "MIN_ROWS",
+ "PACK_KEYS",
+ "ROW_FORMAT",
+ "KEY_BLOCK_SIZE",
+ ):
self._add_option_word(option)
- self._add_option_regex('UNION', r'\([^\)]+\)')
- self._add_option_regex('TABLESPACE', r'.*? STORAGE DISK')
+ self._add_option_regex("UNION", r"\([^\)]+\)")
+ self._add_option_regex("TABLESPACE", r".*? STORAGE DISK")
self._add_option_regex(
- 'RAID_TYPE',
- r'\w+\s+RAID_CHUNKS\s*\=\s*\w+RAID_CHUNKSIZE\s*=\s*\w+')
+ "RAID_TYPE",
+ r"\w+\s+RAID_CHUNKS\s*\=\s*\w+RAID_CHUNKSIZE\s*=\s*\w+",
+ )
- _optional_equals = r'(?:\s*(?:=\s*)|\s+)'
+ _optional_equals = r"(?:\s*(?:=\s*)|\s+)"
def _add_option_string(self, directive):
- regex = (r'(?P<directive>%s)%s'
- r"'(?P<val>(?:[^']|'')*?)'(?!')" %
- (re.escape(directive), self._optional_equals))
- self._pr_options.append(_pr_compile(
- regex, lambda v: v.replace("\\\\", "\\").replace("''", "'")
- ))
+ regex = r"(?P<directive>%s)%s" r"'(?P<val>(?:[^']|'')*?)'(?!')" % (
+ re.escape(directive),
+ self._optional_equals,
+ )
+ self._pr_options.append(
+ _pr_compile(
+ regex, lambda v: v.replace("\\\\", "\\").replace("''", "'")
+ )
+ )
def _add_option_word(self, directive):
- regex = (r'(?P<directive>%s)%s'
- r'(?P<val>\w+)' %
- (re.escape(directive), self._optional_equals))
+ regex = r"(?P<directive>%s)%s" r"(?P<val>\w+)" % (
+ re.escape(directive),
+ self._optional_equals,
+ )
self._pr_options.append(_pr_compile(regex))
def _add_option_regex(self, directive, regex):
- regex = (r'(?P<directive>%s)%s'
- r'(?P<val>%s)' %
- (re.escape(directive), self._optional_equals, regex))
+ regex = r"(?P<directive>%s)%s" r"(?P<val>%s)" % (
+ re.escape(directive),
+ self._optional_equals,
+ regex,
+ )
self._pr_options.append(_pr_compile(regex))
-_options_of_type_string = ('COMMENT', 'DATA DIRECTORY', 'INDEX DIRECTORY',
- 'PASSWORD', 'CONNECTION')
+
+_options_of_type_string = (
+ "COMMENT",
+ "DATA DIRECTORY",
+ "INDEX DIRECTORY",
+ "PASSWORD",
+ "CONNECTION",
+)
def _pr_compile(regex, cleanup=None):
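
The parser above is driven by paired tables: `_pr_compile` bundles a pattern with a cleanup callable, and `parse()` dispatches each line of `SHOW CREATE TABLE` output to whichever pattern matches. It also honors MySQL's versioned-comment convention, `/*!50100 WITH PARSER ngram */`, meaning "if the server is at least 5.1.0, execute the comment body". A standalone check of the `_re_key_version_sql` pattern from the hunk, against a made-up comment body:

    import re

    # Same pattern as _re_key_version_sql in the hunk above.
    key_version_sql = re.compile(r"\!\d+ " r"(?: *WITH PARSER +(?P<parser>\S+) *)?")

    m = key_version_sql.match("!50100 WITH PARSER ngram ")
    assert m is not None and m.group("parser") == "ngram"
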
diff --git a/lib/sqlalchemy/dialects/mysql/types.py b/lib/sqlalchemy/dialects/mysql/types.py
index cb09a0841..ad97a9bbe 100644
--- a/lib/sqlalchemy/dialects/mysql/types.py
+++ b/lib/sqlalchemy/dialects/mysql/types.py
@@ -24,28 +24,30 @@ class _NumericType(object):
super(_NumericType, self).__init__(**kw)
def __repr__(self):
- return util.generic_repr(self,
- to_inspect=[_NumericType, sqltypes.Numeric])
+ return util.generic_repr(
+ self, to_inspect=[_NumericType, sqltypes.Numeric]
+ )
class _FloatType(_NumericType, sqltypes.Float):
def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
- if isinstance(self, (REAL, DOUBLE)) and \
- (
- (precision is None and scale is not None) or
- (precision is not None and scale is None)
+ if isinstance(self, (REAL, DOUBLE)) and (
+ (precision is None and scale is not None)
+ or (precision is not None and scale is None)
):
raise exc.ArgumentError(
"You must specify both precision and scale or omit "
- "both altogether.")
+ "both altogether."
+ )
super(_FloatType, self).__init__(
- precision=precision, asdecimal=asdecimal, **kw)
+ precision=precision, asdecimal=asdecimal, **kw
+ )
self.scale = scale
def __repr__(self):
- return util.generic_repr(self, to_inspect=[_FloatType,
- _NumericType,
- sqltypes.Float])
+ return util.generic_repr(
+ self, to_inspect=[_FloatType, _NumericType, sqltypes.Float]
+ )
class _IntegerType(_NumericType, sqltypes.Integer):
@@ -54,21 +56,28 @@ class _IntegerType(_NumericType, sqltypes.Integer):
super(_IntegerType, self).__init__(**kw)
def __repr__(self):
- return util.generic_repr(self, to_inspect=[_IntegerType,
- _NumericType,
- sqltypes.Integer])
+ return util.generic_repr(
+ self, to_inspect=[_IntegerType, _NumericType, sqltypes.Integer]
+ )
class _StringType(sqltypes.String):
"""Base for MySQL string types."""
- def __init__(self, charset=None, collation=None,
- ascii=False, binary=False, unicode=False,
- national=False, **kw):
+ def __init__(
+ self,
+ charset=None,
+ collation=None,
+ ascii=False,
+ binary=False,
+ unicode=False,
+ national=False,
+ **kw
+ ):
self.charset = charset
# allow collate= or collation=
- kw.setdefault('collation', kw.pop('collate', collation))
+ kw.setdefault("collation", kw.pop("collate", collation))
self.ascii = ascii
self.unicode = unicode
@@ -77,8 +86,9 @@ class _StringType(sqltypes.String):
super(_StringType, self).__init__(**kw)
def __repr__(self):
- return util.generic_repr(self,
- to_inspect=[_StringType, sqltypes.String])
+ return util.generic_repr(
+ self, to_inspect=[_StringType, sqltypes.String]
+ )
class _MatchType(sqltypes.Float, sqltypes.MatchType):
@@ -88,11 +98,10 @@ class _MatchType(sqltypes.Float, sqltypes.MatchType):
sqltypes.MatchType.__init__(self)
-
class NUMERIC(_NumericType, sqltypes.NUMERIC):
"""MySQL NUMERIC type."""
- __visit_name__ = 'NUMERIC'
+ __visit_name__ = "NUMERIC"
def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
"""Construct a NUMERIC.
@@ -110,14 +119,15 @@ class NUMERIC(_NumericType, sqltypes.NUMERIC):
numeric.
"""
- super(NUMERIC, self).__init__(precision=precision,
- scale=scale, asdecimal=asdecimal, **kw)
+ super(NUMERIC, self).__init__(
+ precision=precision, scale=scale, asdecimal=asdecimal, **kw
+ )
class DECIMAL(_NumericType, sqltypes.DECIMAL):
"""MySQL DECIMAL type."""
- __visit_name__ = 'DECIMAL'
+ __visit_name__ = "DECIMAL"
def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
"""Construct a DECIMAL.
@@ -135,14 +145,15 @@ class DECIMAL(_NumericType, sqltypes.DECIMAL):
numeric.
"""
- super(DECIMAL, self).__init__(precision=precision, scale=scale,
- asdecimal=asdecimal, **kw)
+ super(DECIMAL, self).__init__(
+ precision=precision, scale=scale, asdecimal=asdecimal, **kw
+ )
class DOUBLE(_FloatType):
"""MySQL DOUBLE type."""
- __visit_name__ = 'DOUBLE'
+ __visit_name__ = "DOUBLE"
def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
"""Construct a DOUBLE.
@@ -168,14 +179,15 @@ class DOUBLE(_FloatType):
numeric.
"""
- super(DOUBLE, self).__init__(precision=precision, scale=scale,
- asdecimal=asdecimal, **kw)
+ super(DOUBLE, self).__init__(
+ precision=precision, scale=scale, asdecimal=asdecimal, **kw
+ )
class REAL(_FloatType, sqltypes.REAL):
"""MySQL REAL type."""
- __visit_name__ = 'REAL'
+ __visit_name__ = "REAL"
def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
"""Construct a REAL.
@@ -201,14 +213,15 @@ class REAL(_FloatType, sqltypes.REAL):
numeric.
"""
- super(REAL, self).__init__(precision=precision, scale=scale,
- asdecimal=asdecimal, **kw)
+ super(REAL, self).__init__(
+ precision=precision, scale=scale, asdecimal=asdecimal, **kw
+ )
class FLOAT(_FloatType, sqltypes.FLOAT):
"""MySQL FLOAT type."""
- __visit_name__ = 'FLOAT'
+ __visit_name__ = "FLOAT"
def __init__(self, precision=None, scale=None, asdecimal=False, **kw):
"""Construct a FLOAT.
@@ -226,8 +239,9 @@ class FLOAT(_FloatType, sqltypes.FLOAT):
numeric.
"""
- super(FLOAT, self).__init__(precision=precision, scale=scale,
- asdecimal=asdecimal, **kw)
+ super(FLOAT, self).__init__(
+ precision=precision, scale=scale, asdecimal=asdecimal, **kw
+ )
def bind_processor(self, dialect):
return None
@@ -236,7 +250,7 @@ class FLOAT(_FloatType, sqltypes.FLOAT):
class INTEGER(_IntegerType, sqltypes.INTEGER):
"""MySQL INTEGER type."""
- __visit_name__ = 'INTEGER'
+ __visit_name__ = "INTEGER"
def __init__(self, display_width=None, **kw):
"""Construct an INTEGER.
@@ -257,7 +271,7 @@ class INTEGER(_IntegerType, sqltypes.INTEGER):
class BIGINT(_IntegerType, sqltypes.BIGINT):
"""MySQL BIGINTEGER type."""
- __visit_name__ = 'BIGINT'
+ __visit_name__ = "BIGINT"
def __init__(self, display_width=None, **kw):
"""Construct a BIGINTEGER.
@@ -278,7 +292,7 @@ class BIGINT(_IntegerType, sqltypes.BIGINT):
class MEDIUMINT(_IntegerType):
"""MySQL MEDIUMINTEGER type."""
- __visit_name__ = 'MEDIUMINT'
+ __visit_name__ = "MEDIUMINT"
def __init__(self, display_width=None, **kw):
"""Construct a MEDIUMINTEGER
@@ -299,7 +313,7 @@ class MEDIUMINT(_IntegerType):
class TINYINT(_IntegerType):
"""MySQL TINYINT type."""
- __visit_name__ = 'TINYINT'
+ __visit_name__ = "TINYINT"
def __init__(self, display_width=None, **kw):
"""Construct a TINYINT.
@@ -320,7 +334,7 @@ class TINYINT(_IntegerType):
class SMALLINT(_IntegerType, sqltypes.SMALLINT):
"""MySQL SMALLINTEGER type."""
- __visit_name__ = 'SMALLINT'
+ __visit_name__ = "SMALLINT"
def __init__(self, display_width=None, **kw):
"""Construct a SMALLINTEGER.
@@ -347,7 +361,7 @@ class BIT(sqltypes.TypeEngine):
"""
- __visit_name__ = 'BIT'
+ __visit_name__ = "BIT"
def __init__(self, length=None):
"""Construct a BIT.
@@ -374,13 +388,14 @@ class BIT(sqltypes.TypeEngine):
v = v << 8 | i
return v
return value
+
return process
class TIME(sqltypes.TIME):
"""MySQL TIME type. """
- __visit_name__ = 'TIME'
+ __visit_name__ = "TIME"
def __init__(self, timezone=False, fsp=None):
"""Construct a MySQL TIME type.
@@ -413,12 +428,15 @@ class TIME(sqltypes.TIME):
microseconds = value.microseconds
seconds = value.seconds
minutes = seconds // 60
- return time(minutes // 60,
- minutes % 60,
- seconds - minutes * 60,
- microsecond=microseconds)
+ return time(
+ minutes // 60,
+ minutes % 60,
+ seconds - minutes * 60,
+ microsecond=microseconds,
+ )
else:
return None
+
return process
@@ -427,7 +445,7 @@ class TIMESTAMP(sqltypes.TIMESTAMP):
"""
- __visit_name__ = 'TIMESTAMP'
+ __visit_name__ = "TIMESTAMP"
def __init__(self, timezone=False, fsp=None):
"""Construct a MySQL TIMESTAMP type.
@@ -457,7 +475,7 @@ class DATETIME(sqltypes.DATETIME):
"""
- __visit_name__ = 'DATETIME'
+ __visit_name__ = "DATETIME"
def __init__(self, timezone=False, fsp=None):
"""Construct a MySQL DATETIME type.
@@ -485,7 +503,7 @@ class DATETIME(sqltypes.DATETIME):
class YEAR(sqltypes.TypeEngine):
"""MySQL YEAR type, for single byte storage of years 1901-2155."""
- __visit_name__ = 'YEAR'
+ __visit_name__ = "YEAR"
def __init__(self, display_width=None):
self.display_width = display_width
@@ -494,7 +512,7 @@ class YEAR(sqltypes.TypeEngine):
class TEXT(_StringType, sqltypes.TEXT):
"""MySQL TEXT type, for text up to 2^16 characters."""
- __visit_name__ = 'TEXT'
+ __visit_name__ = "TEXT"
def __init__(self, length=None, **kw):
"""Construct a TEXT.
@@ -530,7 +548,7 @@ class TEXT(_StringType, sqltypes.TEXT):
class TINYTEXT(_StringType):
"""MySQL TINYTEXT type, for text up to 2^8 characters."""
- __visit_name__ = 'TINYTEXT'
+ __visit_name__ = "TINYTEXT"
def __init__(self, **kwargs):
"""Construct a TINYTEXT.
@@ -562,7 +580,7 @@ class TINYTEXT(_StringType):
class MEDIUMTEXT(_StringType):
"""MySQL MEDIUMTEXT type, for text up to 2^24 characters."""
- __visit_name__ = 'MEDIUMTEXT'
+ __visit_name__ = "MEDIUMTEXT"
def __init__(self, **kwargs):
"""Construct a MEDIUMTEXT.
@@ -594,7 +612,7 @@ class MEDIUMTEXT(_StringType):
class LONGTEXT(_StringType):
"""MySQL LONGTEXT type, for text up to 2^32 characters."""
- __visit_name__ = 'LONGTEXT'
+ __visit_name__ = "LONGTEXT"
def __init__(self, **kwargs):
"""Construct a LONGTEXT.
@@ -626,7 +644,7 @@ class LONGTEXT(_StringType):
class VARCHAR(_StringType, sqltypes.VARCHAR):
"""MySQL VARCHAR type, for variable-length character data."""
- __visit_name__ = 'VARCHAR'
+ __visit_name__ = "VARCHAR"
def __init__(self, length=None, **kwargs):
"""Construct a VARCHAR.
@@ -658,7 +676,7 @@ class VARCHAR(_StringType, sqltypes.VARCHAR):
class CHAR(_StringType, sqltypes.CHAR):
"""MySQL CHAR type, for fixed-length character data."""
- __visit_name__ = 'CHAR'
+ __visit_name__ = "CHAR"
def __init__(self, length=None, **kwargs):
"""Construct a CHAR.
@@ -690,7 +708,7 @@ class CHAR(_StringType, sqltypes.CHAR):
ascii=type_.ascii,
binary=type_.binary,
unicode=type_.unicode,
- national=False # not supported in CAST
+ national=False, # not supported in CAST
)
else:
return CHAR(length=type_.length)
@@ -703,7 +721,7 @@ class NVARCHAR(_StringType, sqltypes.NVARCHAR):
character set.
"""
- __visit_name__ = 'NVARCHAR'
+ __visit_name__ = "NVARCHAR"
def __init__(self, length=None, **kwargs):
"""Construct an NVARCHAR.
@@ -718,7 +736,7 @@ class NVARCHAR(_StringType, sqltypes.NVARCHAR):
compatible with the national character set.
"""
- kwargs['national'] = True
+ kwargs["national"] = True
super(NVARCHAR, self).__init__(length=length, **kwargs)
@@ -729,7 +747,7 @@ class NCHAR(_StringType, sqltypes.NCHAR):
character set.
"""
- __visit_name__ = 'NCHAR'
+ __visit_name__ = "NCHAR"
def __init__(self, length=None, **kwargs):
"""Construct an NCHAR.
@@ -744,23 +762,23 @@ class NCHAR(_StringType, sqltypes.NCHAR):
compatible with the national character set.
"""
- kwargs['national'] = True
+ kwargs["national"] = True
super(NCHAR, self).__init__(length=length, **kwargs)
class TINYBLOB(sqltypes._Binary):
"""MySQL TINYBLOB type, for binary data up to 2^8 bytes."""
- __visit_name__ = 'TINYBLOB'
+ __visit_name__ = "TINYBLOB"
class MEDIUMBLOB(sqltypes._Binary):
"""MySQL MEDIUMBLOB type, for binary data up to 2^24 bytes."""
- __visit_name__ = 'MEDIUMBLOB'
+ __visit_name__ = "MEDIUMBLOB"
class LONGBLOB(sqltypes._Binary):
"""MySQL LONGBLOB type, for binary data up to 2^32 bytes."""
- __visit_name__ = 'LONGBLOB'
+ __visit_name__ = "LONGBLOB"
diff --git a/lib/sqlalchemy/dialects/mysql/zxjdbc.py b/lib/sqlalchemy/dialects/mysql/zxjdbc.py
index 4aee2dbb7..d8ee43748 100644
--- a/lib/sqlalchemy/dialects/mysql/zxjdbc.py
+++ b/lib/sqlalchemy/dialects/mysql/zxjdbc.py
@@ -37,6 +37,7 @@ from .base import BIT, MySQLDialect, MySQLExecutionContext
class _ZxJDBCBit(BIT):
def result_processor(self, dialect, coltype):
"""Converts boolean or byte arrays from MySQL Connector/J to longs."""
+
def process(value):
if value is None:
return value
@@ -44,9 +45,10 @@ class _ZxJDBCBit(BIT):
return int(value)
v = 0
for i in value:
- v = v << 8 | (i & 0xff)
+ v = v << 8 | (i & 0xFF)
value = v
return value
+
return process
@@ -60,17 +62,13 @@ class MySQLExecutionContext_zxjdbc(MySQLExecutionContext):
class MySQLDialect_zxjdbc(ZxJDBCConnector, MySQLDialect):
- jdbc_db_name = 'mysql'
- jdbc_driver_name = 'com.mysql.jdbc.Driver'
+ jdbc_db_name = "mysql"
+ jdbc_driver_name = "com.mysql.jdbc.Driver"
execution_ctx_cls = MySQLExecutionContext_zxjdbc
colspecs = util.update_copy(
- MySQLDialect.colspecs,
- {
- sqltypes.Time: sqltypes.Time,
- BIT: _ZxJDBCBit
- }
+ MySQLDialect.colspecs, {sqltypes.Time: sqltypes.Time, BIT: _ZxJDBCBit}
)
def _detect_charset(self, connection):
@@ -83,17 +81,19 @@ class MySQLDialect_zxjdbc(ZxJDBCConnector, MySQLDialect):
# this can prefer the driver value.
rs = connection.execute("SHOW VARIABLES LIKE 'character_set%%'")
opts = {row[0]: row[1] for row in self._compat_fetchall(rs)}
- for key in ('character_set_connection', 'character_set'):
+ for key in ("character_set_connection", "character_set"):
if opts.get(key, None):
return opts[key]
- util.warn("Could not detect the connection character set. "
- "Assuming latin1.")
- return 'latin1'
+ util.warn(
+ "Could not detect the connection character set. "
+ "Assuming latin1."
+ )
+ return "latin1"
def _driver_kwargs(self):
"""return kw arg dict to be sent to connect()."""
- return dict(characterEncoding='UTF-8', yearIsDateType='false')
+ return dict(characterEncoding="UTF-8", yearIsDateType="false")
def _extract_error_code(self, exception):
# e.g.: DBAPIError: (Error) Table 'test.u2' doesn't exist
@@ -106,7 +106,7 @@ class MySQLDialect_zxjdbc(ZxJDBCConnector, MySQLDialect):
def _get_server_version_info(self, connection):
dbapi_con = connection.connection
version = []
- r = re.compile(r'[.\-]')
+ r = re.compile(r"[.\-]")
for n in r.split(dbapi_con.dbversion):
try:
version.append(int(n))
@@ -114,4 +114,5 @@ class MySQLDialect_zxjdbc(ZxJDBCConnector, MySQLDialect):
version.append(n)
return tuple(version)
+
dialect = MySQLDialect_zxjdbc
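
`_ZxJDBCBit.process` above folds the byte array Connector/J returns into a single integer, masking each byte with `0xFF` because Java bytes are signed. The same fold on plain Python ints:

    def bits_to_long(value):
        # Mirrors _ZxJDBCBit: big-endian fold, masking signed Java bytes.
        v = 0
        for i in value:
            v = v << 8 | (i & 0xFF)
        return v

    assert bits_to_long([0x01, 0x02]) == 258
    assert bits_to_long([-1]) == 255  # a signed Java byte of -1 is 0xFF
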
diff --git a/lib/sqlalchemy/dialects/oracle/__init__.py b/lib/sqlalchemy/dialects/oracle/__init__.py
index e3d9fed2c..1b9007fcc 100644
--- a/lib/sqlalchemy/dialects/oracle/__init__.py
+++ b/lib/sqlalchemy/dialects/oracle/__init__.py
@@ -7,18 +7,51 @@
from . import base, cx_oracle, zxjdbc # noqa
-from .base import \
- VARCHAR, NVARCHAR, CHAR, DATE, NUMBER,\
- BLOB, BFILE, BINARY_FLOAT, BINARY_DOUBLE, CLOB, NCLOB, TIMESTAMP, RAW,\
- FLOAT, DOUBLE_PRECISION, LONG, INTERVAL,\
- VARCHAR2, NVARCHAR2, ROWID
+from .base import (
+ VARCHAR,
+ NVARCHAR,
+ CHAR,
+ DATE,
+ NUMBER,
+ BLOB,
+ BFILE,
+ BINARY_FLOAT,
+ BINARY_DOUBLE,
+ CLOB,
+ NCLOB,
+ TIMESTAMP,
+ RAW,
+ FLOAT,
+ DOUBLE_PRECISION,
+ LONG,
+ INTERVAL,
+ VARCHAR2,
+ NVARCHAR2,
+ ROWID,
+)
base.dialect = dialect = cx_oracle.dialect
__all__ = (
- 'VARCHAR', 'NVARCHAR', 'CHAR', 'DATE', 'NUMBER',
- 'BLOB', 'BFILE', 'CLOB', 'NCLOB', 'TIMESTAMP', 'RAW',
- 'FLOAT', 'DOUBLE_PRECISION', 'BINARY_DOUBLE', 'BINARY_FLOAT',
- 'LONG', 'dialect', 'INTERVAL',
- 'VARCHAR2', 'NVARCHAR2', 'ROWID'
+ "VARCHAR",
+ "NVARCHAR",
+ "CHAR",
+ "DATE",
+ "NUMBER",
+ "BLOB",
+ "BFILE",
+ "CLOB",
+ "NCLOB",
+ "TIMESTAMP",
+ "RAW",
+ "FLOAT",
+ "DOUBLE_PRECISION",
+ "BINARY_DOUBLE",
+ "BINARY_FLOAT",
+ "LONG",
+ "dialect",
+ "INTERVAL",
+ "VARCHAR2",
+ "NVARCHAR2",
+ "ROWID",
)
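
The rewrapped import and `__all__` lists above change formatting only; the public names are consumed exactly as before, e.g.:

    from sqlalchemy import Column, MetaData, Table
    from sqlalchemy.dialects.oracle import NUMBER, VARCHAR2

    accounts = Table(
        "accounts",
        MetaData(),
        Column("id", NUMBER(10)),      # renders as NUMBER(10)
        Column("name", VARCHAR2(50)),  # typically renders as VARCHAR2(50 CHAR)
    )
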
diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py
index b5aea4386..944fe21c3 100644
--- a/lib/sqlalchemy/dialects/oracle/base.py
+++ b/lib/sqlalchemy/dialects/oracle/base.py
@@ -353,49 +353,63 @@ from sqlalchemy.sql import compiler, visitors, expression, util as sql_util
from sqlalchemy.sql import operators as sql_operators
from sqlalchemy.sql.elements import quoted_name
from sqlalchemy import types as sqltypes, schema as sa_schema
-from sqlalchemy.types import VARCHAR, NVARCHAR, CHAR, \
- BLOB, CLOB, TIMESTAMP, FLOAT, INTEGER
+from sqlalchemy.types import (
+ VARCHAR,
+ NVARCHAR,
+ CHAR,
+ BLOB,
+ CLOB,
+ TIMESTAMP,
+ FLOAT,
+ INTEGER,
+)
from itertools import groupby
-RESERVED_WORDS = \
- set('SHARE RAW DROP BETWEEN FROM DESC OPTION PRIOR LONG THEN '
- 'DEFAULT ALTER IS INTO MINUS INTEGER NUMBER GRANT IDENTIFIED '
- 'ALL TO ORDER ON FLOAT DATE HAVING CLUSTER NOWAIT RESOURCE '
- 'ANY TABLE INDEX FOR UPDATE WHERE CHECK SMALLINT WITH DELETE '
- 'BY ASC REVOKE LIKE SIZE RENAME NOCOMPRESS NULL GROUP VALUES '
- 'AS IN VIEW EXCLUSIVE COMPRESS SYNONYM SELECT INSERT EXISTS '
- 'NOT TRIGGER ELSE CREATE INTERSECT PCTFREE DISTINCT USER '
- 'CONNECT SET MODE OF UNIQUE VARCHAR2 VARCHAR LOCK OR CHAR '
- 'DECIMAL UNION PUBLIC AND START UID COMMENT CURRENT LEVEL'.split())
+RESERVED_WORDS = set(
+ "SHARE RAW DROP BETWEEN FROM DESC OPTION PRIOR LONG THEN "
+ "DEFAULT ALTER IS INTO MINUS INTEGER NUMBER GRANT IDENTIFIED "
+ "ALL TO ORDER ON FLOAT DATE HAVING CLUSTER NOWAIT RESOURCE "
+ "ANY TABLE INDEX FOR UPDATE WHERE CHECK SMALLINT WITH DELETE "
+ "BY ASC REVOKE LIKE SIZE RENAME NOCOMPRESS NULL GROUP VALUES "
+ "AS IN VIEW EXCLUSIVE COMPRESS SYNONYM SELECT INSERT EXISTS "
+ "NOT TRIGGER ELSE CREATE INTERSECT PCTFREE DISTINCT USER "
+ "CONNECT SET MODE OF UNIQUE VARCHAR2 VARCHAR LOCK OR CHAR "
+ "DECIMAL UNION PUBLIC AND START UID COMMENT CURRENT LEVEL".split()
+)
-NO_ARG_FNS = set('UID CURRENT_DATE SYSDATE USER '
- 'CURRENT_TIME CURRENT_TIMESTAMP'.split())
+NO_ARG_FNS = set(
+ "UID CURRENT_DATE SYSDATE USER " "CURRENT_TIME CURRENT_TIMESTAMP".split()
+)
class RAW(sqltypes._Binary):
- __visit_name__ = 'RAW'
+ __visit_name__ = "RAW"
+
+
OracleRaw = RAW
class NCLOB(sqltypes.Text):
- __visit_name__ = 'NCLOB'
+ __visit_name__ = "NCLOB"
class VARCHAR2(VARCHAR):
- __visit_name__ = 'VARCHAR2'
+ __visit_name__ = "VARCHAR2"
+
NVARCHAR2 = NVARCHAR
class NUMBER(sqltypes.Numeric, sqltypes.Integer):
- __visit_name__ = 'NUMBER'
+ __visit_name__ = "NUMBER"
def __init__(self, precision=None, scale=None, asdecimal=None):
if asdecimal is None:
asdecimal = bool(scale and scale > 0)
super(NUMBER, self).__init__(
- precision=precision, scale=scale, asdecimal=asdecimal)
+ precision=precision, scale=scale, asdecimal=asdecimal
+ )
def adapt(self, impltype):
ret = super(NUMBER, self).adapt(impltype)
@@ -412,23 +426,23 @@ class NUMBER(sqltypes.Numeric, sqltypes.Integer):
class DOUBLE_PRECISION(sqltypes.Float):
- __visit_name__ = 'DOUBLE_PRECISION'
+ __visit_name__ = "DOUBLE_PRECISION"
class BINARY_DOUBLE(sqltypes.Float):
- __visit_name__ = 'BINARY_DOUBLE'
+ __visit_name__ = "BINARY_DOUBLE"
class BINARY_FLOAT(sqltypes.Float):
- __visit_name__ = 'BINARY_FLOAT'
+ __visit_name__ = "BINARY_FLOAT"
class BFILE(sqltypes.LargeBinary):
- __visit_name__ = 'BFILE'
+ __visit_name__ = "BFILE"
class LONG(sqltypes.Text):
- __visit_name__ = 'LONG'
+ __visit_name__ = "LONG"
class DATE(sqltypes.DateTime):
@@ -441,18 +455,17 @@ class DATE(sqltypes.DateTime):
.. versionadded:: 0.9.4
"""
- __visit_name__ = 'DATE'
+
+ __visit_name__ = "DATE"
def _compare_type_affinity(self, other):
return other._type_affinity in (sqltypes.DateTime, sqltypes.Date)
class INTERVAL(sqltypes.TypeEngine):
- __visit_name__ = 'INTERVAL'
+ __visit_name__ = "INTERVAL"
- def __init__(self,
- day_precision=None,
- second_precision=None):
+ def __init__(self, day_precision=None, second_precision=None):
"""Construct an INTERVAL.
Note that only DAY TO SECOND intervals are currently supported.
@@ -471,8 +484,10 @@ class INTERVAL(sqltypes.TypeEngine):
@classmethod
def _adapt_from_generic_interval(cls, interval):
- return INTERVAL(day_precision=interval.day_precision,
- second_precision=interval.second_precision)
+ return INTERVAL(
+ day_precision=interval.day_precision,
+ second_precision=interval.second_precision,
+ )
@property
def _type_affinity(self):
@@ -485,38 +500,40 @@ class ROWID(sqltypes.TypeEngine):
When used in a cast() or similar, generates ROWID.
"""
- __visit_name__ = 'ROWID'
+
+ __visit_name__ = "ROWID"
class _OracleBoolean(sqltypes.Boolean):
def get_dbapi_type(self, dbapi):
return dbapi.NUMBER
+
colspecs = {
sqltypes.Boolean: _OracleBoolean,
sqltypes.Interval: INTERVAL,
- sqltypes.DateTime: DATE
+ sqltypes.DateTime: DATE,
}
ischema_names = {
- 'VARCHAR2': VARCHAR,
- 'NVARCHAR2': NVARCHAR,
- 'CHAR': CHAR,
- 'DATE': DATE,
- 'NUMBER': NUMBER,
- 'BLOB': BLOB,
- 'BFILE': BFILE,
- 'CLOB': CLOB,
- 'NCLOB': NCLOB,
- 'TIMESTAMP': TIMESTAMP,
- 'TIMESTAMP WITH TIME ZONE': TIMESTAMP,
- 'INTERVAL DAY TO SECOND': INTERVAL,
- 'RAW': RAW,
- 'FLOAT': FLOAT,
- 'DOUBLE PRECISION': DOUBLE_PRECISION,
- 'LONG': LONG,
- 'BINARY_DOUBLE': BINARY_DOUBLE,
- 'BINARY_FLOAT': BINARY_FLOAT
+ "VARCHAR2": VARCHAR,
+ "NVARCHAR2": NVARCHAR,
+ "CHAR": CHAR,
+ "DATE": DATE,
+ "NUMBER": NUMBER,
+ "BLOB": BLOB,
+ "BFILE": BFILE,
+ "CLOB": CLOB,
+ "NCLOB": NCLOB,
+ "TIMESTAMP": TIMESTAMP,
+ "TIMESTAMP WITH TIME ZONE": TIMESTAMP,
+ "INTERVAL DAY TO SECOND": INTERVAL,
+ "RAW": RAW,
+ "FLOAT": FLOAT,
+ "DOUBLE PRECISION": DOUBLE_PRECISION,
+ "LONG": LONG,
+ "BINARY_DOUBLE": BINARY_DOUBLE,
+ "BINARY_FLOAT": BINARY_FLOAT,
}
@@ -540,12 +557,12 @@ class OracleTypeCompiler(compiler.GenericTypeCompiler):
def visit_INTERVAL(self, type_, **kw):
return "INTERVAL DAY%s TO SECOND%s" % (
- type_.day_precision is not None and
- "(%d)" % type_.day_precision or
- "",
- type_.second_precision is not None and
- "(%d)" % type_.second_precision or
- "",
+ type_.day_precision is not None
+ and "(%d)" % type_.day_precision
+ or "",
+ type_.second_precision is not None
+ and "(%d)" % type_.second_precision
+ or "",
)
def visit_LONG(self, type_, **kw):
@@ -569,52 +586,53 @@ class OracleTypeCompiler(compiler.GenericTypeCompiler):
def visit_FLOAT(self, type_, **kw):
# don't support conversion between decimal/binary
# precision yet
- kw['no_precision'] = True
+ kw["no_precision"] = True
return self._generate_numeric(type_, "FLOAT", **kw)
def visit_NUMBER(self, type_, **kw):
return self._generate_numeric(type_, "NUMBER", **kw)
def _generate_numeric(
- self, type_, name, precision=None,
- scale=None, no_precision=False, **kw):
+ self, type_, name, precision=None, scale=None, no_precision=False, **kw
+ ):
if precision is None:
precision = type_.precision
if scale is None:
- scale = getattr(type_, 'scale', None)
+ scale = getattr(type_, "scale", None)
if no_precision or precision is None:
return name
elif scale is None:
n = "%(name)s(%(precision)s)"
- return n % {'name': name, 'precision': precision}
+ return n % {"name": name, "precision": precision}
else:
n = "%(name)s(%(precision)s, %(scale)s)"
- return n % {'name': name, 'precision': precision, 'scale': scale}
+ return n % {"name": name, "precision": precision, "scale": scale}
def visit_string(self, type_, **kw):
return self.visit_VARCHAR2(type_, **kw)
def visit_VARCHAR2(self, type_, **kw):
- return self._visit_varchar(type_, '', '2')
+ return self._visit_varchar(type_, "", "2")
def visit_NVARCHAR2(self, type_, **kw):
- return self._visit_varchar(type_, 'N', '2')
+ return self._visit_varchar(type_, "N", "2")
+
visit_NVARCHAR = visit_NVARCHAR2
def visit_VARCHAR(self, type_, **kw):
- return self._visit_varchar(type_, '', '')
+ return self._visit_varchar(type_, "", "")
def _visit_varchar(self, type_, n, num):
if not type_.length:
- return "%(n)sVARCHAR%(two)s" % {'two': num, 'n': n}
+ return "%(n)sVARCHAR%(two)s" % {"two": num, "n": n}
elif not n and self.dialect._supports_char_length:
varchar = "VARCHAR%(two)s(%(length)s CHAR)"
- return varchar % {'length': type_.length, 'two': num}
+ return varchar % {"length": type_.length, "two": num}
else:
varchar = "%(n)sVARCHAR%(two)s(%(length)s)"
- return varchar % {'length': type_.length, 'two': num, 'n': n}
+ return varchar % {"length": type_.length, "two": num, "n": n}
def visit_text(self, type_, **kw):
return self.visit_CLOB(type_, **kw)
@@ -636,7 +654,7 @@ class OracleTypeCompiler(compiler.GenericTypeCompiler):
def visit_RAW(self, type_, **kw):
if type_.length:
- return "RAW(%(length)s)" % {'length': type_.length}
+ return "RAW(%(length)s)" % {"length": type_.length}
else:
return "RAW"
@@ -652,9 +670,7 @@ class OracleCompiler(compiler.SQLCompiler):
compound_keywords = util.update_copy(
compiler.SQLCompiler.compound_keywords,
- {
- expression.CompoundSelect.EXCEPT: 'MINUS'
- }
+ {expression.CompoundSelect.EXCEPT: "MINUS"},
)
def __init__(self, *args, **kwargs):
@@ -663,8 +679,10 @@ class OracleCompiler(compiler.SQLCompiler):
super(OracleCompiler, self).__init__(*args, **kwargs)
def visit_mod_binary(self, binary, operator, **kw):
- return "mod(%s, %s)" % (self.process(binary.left, **kw),
- self.process(binary.right, **kw))
+ return "mod(%s, %s)" % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw),
+ )
def visit_now_func(self, fn, **kw):
return "CURRENT_TIMESTAMP"
@@ -673,22 +691,22 @@ class OracleCompiler(compiler.SQLCompiler):
return "LENGTH" + self.function_argspec(fn, **kw)
def visit_match_op_binary(self, binary, operator, **kw):
- return "CONTAINS (%s, %s)" % (self.process(binary.left),
- self.process(binary.right))
+ return "CONTAINS (%s, %s)" % (
+ self.process(binary.left),
+ self.process(binary.right),
+ )
def visit_true(self, expr, **kw):
- return '1'
+ return "1"
def visit_false(self, expr, **kw):
- return '0'
+ return "0"
def get_cte_preamble(self, recursive):
return "WITH"
def get_select_hint_text(self, byfroms):
- return " ".join(
- "/*+ %s */" % text for table, text in byfroms.items()
- )
+ return " ".join("/*+ %s */" % text for table, text in byfroms.items())
def function_argspec(self, fn, **kw):
if len(fn.clauses) > 0 or fn.name.upper() not in NO_ARG_FNS:
@@ -709,13 +727,16 @@ class OracleCompiler(compiler.SQLCompiler):
if self.dialect.use_ansi:
return compiler.SQLCompiler.visit_join(self, join, **kwargs)
else:
- kwargs['asfrom'] = True
+ kwargs["asfrom"] = True
if isinstance(join.right, expression.FromGrouping):
right = join.right.element
else:
right = join.right
- return self.process(join.left, **kwargs) + \
- ", " + self.process(right, **kwargs)
+ return (
+ self.process(join.left, **kwargs)
+ + ", "
+ + self.process(right, **kwargs)
+ )
def _get_nonansi_join_whereclause(self, froms):
clauses = []
@@ -727,14 +748,20 @@ class OracleCompiler(compiler.SQLCompiler):
# the join condition in the WHERE clause" - that is,
# unconditionally regardless of operator or the other side
def visit_binary(binary):
- if isinstance(binary.left, expression.ColumnClause) \
- and join.right.is_derived_from(binary.left.table):
+ if isinstance(
+ binary.left, expression.ColumnClause
+ ) and join.right.is_derived_from(binary.left.table):
binary.left = _OuterJoinColumn(binary.left)
- elif isinstance(binary.right, expression.ColumnClause) \
- and join.right.is_derived_from(binary.right.table):
+ elif isinstance(
+ binary.right, expression.ColumnClause
+ ) and join.right.is_derived_from(binary.right.table):
binary.right = _OuterJoinColumn(binary.right)
- clauses.append(visitors.cloned_traverse(
- join.onclause, {}, {'binary': visit_binary}))
+
+ clauses.append(
+ visitors.cloned_traverse(
+ join.onclause, {}, {"binary": visit_binary}
+ )
+ )
else:
clauses.append(join.onclause)
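
The hunk above belongs to the pre-ANSI join path: with `use_ansi=False` the compiler rewrites an outer join by tagging the inner side's columns, which `visit_outer_join_column` then renders with Oracle 8's `(+)` marker in the WHERE clause. A compile-only sketch using the base dialect (exact aliasing may differ by version):

    from sqlalchemy import Column, Integer, MetaData, Table, select
    from sqlalchemy.dialects.oracle.base import OracleDialect

    m = MetaData()
    a = Table("a", m, Column("id", Integer))
    b = Table("b", m, Column("a_id", Integer))
    stmt = select([a.c.id]).select_from(
        a.outerjoin(b, a.c.id == b.c.a_id)
    )
    print(stmt.compile(dialect=OracleDialect(use_ansi=False)))
    # Roughly: SELECT a.id FROM a, b WHERE a.id = b.a_id(+)
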
@@ -757,8 +784,9 @@ class OracleCompiler(compiler.SQLCompiler):
return self.process(vc.column, **kw) + "(+)"
def visit_sequence(self, seq, **kw):
- return (self.dialect.identifier_preparer.format_sequence(seq) +
- ".nextval")
+ return (
+ self.dialect.identifier_preparer.format_sequence(seq) + ".nextval"
+ )
def get_render_as_alias_suffix(self, alias_name_text):
"""Oracle doesn't like ``FROM table AS alias``"""
@@ -770,7 +798,8 @@ class OracleCompiler(compiler.SQLCompiler):
binds = []
for i, column in enumerate(
- expression._select_iterables(returning_cols)):
+ expression._select_iterables(returning_cols)
+ ):
if column.type._has_column_expression:
col_expr = column.type.column_expression(column)
else:
@@ -779,19 +808,22 @@ class OracleCompiler(compiler.SQLCompiler):
outparam = sql.outparam("ret_%d" % i, type_=column.type)
self.binds[outparam.key] = outparam
binds.append(
- self.bindparam_string(self._truncate_bindparam(outparam)))
- columns.append(
- self.process(col_expr, within_columns_clause=False))
+ self.bindparam_string(self._truncate_bindparam(outparam))
+ )
+ columns.append(self.process(col_expr, within_columns_clause=False))
self._add_to_result_map(
- getattr(col_expr, 'name', col_expr.anon_label),
- getattr(col_expr, 'name', col_expr.anon_label),
- (column, getattr(column, 'name', None),
- getattr(column, 'key', None)),
- column.type
+ getattr(col_expr, "name", col_expr.anon_label),
+ getattr(col_expr, "name", col_expr.anon_label),
+ (
+ column,
+ getattr(column, "name", None),
+ getattr(column, "key", None),
+ ),
+ column.type,
)
- return 'RETURNING ' + ', '.join(columns) + " INTO " + ", ".join(binds)
+ return "RETURNING " + ", ".join(columns) + " INTO " + ", ".join(binds)
def _TODO_visit_compound_select(self, select):
"""Need to determine how to get ``LIMIT``/``OFFSET`` into a
@@ -804,10 +836,11 @@ class OracleCompiler(compiler.SQLCompiler):
so tries to wrap it in a subquery with ``rownum`` criterion.
"""
- if not getattr(select, '_oracle_visit', None):
+ if not getattr(select, "_oracle_visit", None):
if not self.dialect.use_ansi:
froms = self._display_froms_for_select(
- select, kwargs.get('asfrom', False))
+ select, kwargs.get("asfrom", False)
+ )
whereclause = self._get_nonansi_join_whereclause(froms)
if whereclause is not None:
select = select.where(whereclause)
@@ -828,18 +861,20 @@ class OracleCompiler(compiler.SQLCompiler):
# Outer select and "ROWNUM as ora_rn" can be dropped if
# limit=0
- kwargs['select_wraps_for'] = select
+ kwargs["select_wraps_for"] = select
select = select._generate()
select._oracle_visit = True
# Wrap the middle select and add the hint
limitselect = sql.select([c for c in select.c])
- if limit_clause is not None and \
- self.dialect.optimize_limits and \
- select._simple_int_limit:
+ if (
+ limit_clause is not None
+ and self.dialect.optimize_limits
+ and select._simple_int_limit
+ ):
limitselect = limitselect.prefix_with(
- "/*+ FIRST_ROWS(%d) */" %
- select._limit)
+ "/*+ FIRST_ROWS(%d) */" % select._limit
+ )
limitselect._oracle_visit = True
limitselect._is_wrapper = True
@@ -855,8 +890,8 @@ class OracleCompiler(compiler.SQLCompiler):
adapter = sql_util.ClauseAdapter(select)
for_update.of = [
- adapter.traverse(elem)
- for elem in for_update.of]
+ adapter.traverse(elem) for elem in for_update.of
+ ]
# If needed, add the limiting clause
if limit_clause is not None:
@@ -873,7 +908,8 @@ class OracleCompiler(compiler.SQLCompiler):
if offset_clause is not None:
max_row = max_row + offset_clause
limitselect.append_whereclause(
- sql.literal_column("ROWNUM") <= max_row)
+ sql.literal_column("ROWNUM") <= max_row
+ )
# If needed, add the ora_rn, and wrap again with offset.
if offset_clause is None:
@@ -881,12 +917,14 @@ class OracleCompiler(compiler.SQLCompiler):
select = limitselect
else:
limitselect = limitselect.column(
- sql.literal_column("ROWNUM").label("ora_rn"))
+ sql.literal_column("ROWNUM").label("ora_rn")
+ )
limitselect._oracle_visit = True
limitselect._is_wrapper = True
offsetselect = sql.select(
- [c for c in limitselect.c if c.key != 'ora_rn'])
+ [c for c in limitselect.c if c.key != "ora_rn"]
+ )
offsetselect._oracle_visit = True
offsetselect._is_wrapper = True
@@ -897,9 +935,11 @@ class OracleCompiler(compiler.SQLCompiler):
if not self.dialect.use_binds_for_limits:
offset_clause = sql.literal_column(
- "%d" % select._offset)
+ "%d" % select._offset
+ )
offsetselect.append_whereclause(
- sql.literal_column("ora_rn") > offset_clause)
+ sql.literal_column("ora_rn") > offset_clause
+ )
offsetselect._for_update_arg = for_update
select = offsetselect
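
These hunks implement the docstring's promise above: pre-12c Oracle has no LIMIT/OFFSET, so the compiler nests the statement in ROWNUM-filtered subqueries, labelling the inner row number `ora_rn`. A compile-only sketch of the emitted shape (bind names and aliasing are approximate):

    from sqlalchemy import Column, Integer, MetaData, Table, select
    from sqlalchemy.dialects.oracle.base import OracleDialect

    t = Table("t", MetaData(), Column("id", Integer))
    stmt = select([t.c.id]).limit(10).offset(5)  # 1.x select([...]) form
    print(stmt.compile(dialect=OracleDialect()))
    # Roughly:
    # SELECT id FROM (
    #     SELECT id, ROWNUM AS ora_rn FROM (
    #         SELECT t.id AS id FROM t) WHERE ROWNUM <= :param_1 + :param_2)
    # WHERE ora_rn > :param_2
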
@@ -910,18 +950,17 @@ class OracleCompiler(compiler.SQLCompiler):
return ""
def visit_empty_set_expr(self, type_):
- return 'SELECT 1 FROM DUAL WHERE 1!=1'
+ return "SELECT 1 FROM DUAL WHERE 1!=1"
def for_update_clause(self, select, **kw):
if self.is_subquery():
return ""
- tmp = ' FOR UPDATE'
+ tmp = " FOR UPDATE"
if select._for_update_arg.of:
- tmp += ' OF ' + ', '.join(
- self.process(elem, **kw) for elem in
- select._for_update_arg.of
+ tmp += " OF " + ", ".join(
+ self.process(elem, **kw) for elem in select._for_update_arg.of
)
if select._for_update_arg.nowait:
@@ -933,7 +972,6 @@ class OracleCompiler(compiler.SQLCompiler):
class OracleDDLCompiler(compiler.DDLCompiler):
-
def define_constraint_cascades(self, constraint):
text = ""
if constraint.ondelete is not None:
@@ -947,7 +985,8 @@ class OracleDDLCompiler(compiler.DDLCompiler):
"Oracle does not contain native UPDATE CASCADE "
"functionality - onupdates will not be rendered for foreign "
"keys. Consider using deferrable=True, initially='deferred' "
- "or triggers.")
+ "or triggers."
+ )
return text
@@ -958,75 +997,79 @@ class OracleDDLCompiler(compiler.DDLCompiler):
text = "CREATE "
if index.unique:
text += "UNIQUE "
- if index.dialect_options['oracle']['bitmap']:
+ if index.dialect_options["oracle"]["bitmap"]:
text += "BITMAP "
text += "INDEX %s ON %s (%s)" % (
self._prepared_index_name(index, include_schema=True),
preparer.format_table(index.table, use_schema=True),
- ', '.join(
+ ", ".join(
self.sql_compiler.process(
- expr,
- include_table=False, literal_binds=True)
- for expr in index.expressions)
+ expr, include_table=False, literal_binds=True
+ )
+ for expr in index.expressions
+ ),
)
- if index.dialect_options['oracle']['compress'] is not False:
- if index.dialect_options['oracle']['compress'] is True:
+ if index.dialect_options["oracle"]["compress"] is not False:
+ if index.dialect_options["oracle"]["compress"] is True:
text += " COMPRESS"
else:
text += " COMPRESS %d" % (
- index.dialect_options['oracle']['compress']
+ index.dialect_options["oracle"]["compress"]
)
return text
def post_create_table(self, table):
table_opts = []
- opts = table.dialect_options['oracle']
+ opts = table.dialect_options["oracle"]
- if opts['on_commit']:
- on_commit_options = opts['on_commit'].replace("_", " ").upper()
- table_opts.append('\n ON COMMIT %s' % on_commit_options)
+ if opts["on_commit"]:
+ on_commit_options = opts["on_commit"].replace("_", " ").upper()
+ table_opts.append("\n ON COMMIT %s" % on_commit_options)
- if opts['compress']:
- if opts['compress'] is True:
+ if opts["compress"]:
+ if opts["compress"] is True:
table_opts.append("\n COMPRESS")
else:
- table_opts.append("\n COMPRESS FOR %s" % (
- opts['compress']
- ))
+ table_opts.append("\n COMPRESS FOR %s" % (opts["compress"]))
- return ''.join(table_opts)
+ return "".join(table_opts)
class OracleIdentifierPreparer(compiler.IdentifierPreparer):
reserved_words = {x.lower() for x in RESERVED_WORDS}
- illegal_initial_characters = {str(dig) for dig in range(0, 10)} \
- .union(["_", "$"])
+ illegal_initial_characters = {str(dig) for dig in range(0, 10)}.union(
+ ["_", "$"]
+ )
def _bindparam_requires_quotes(self, value):
"""Return True if the given identifier requires quoting."""
lc_value = value.lower()
- return (lc_value in self.reserved_words
- or value[0] in self.illegal_initial_characters
- or not self.legal_characters.match(util.text_type(value))
- )
+ return (
+ lc_value in self.reserved_words
+ or value[0] in self.illegal_initial_characters
+ or not self.legal_characters.match(util.text_type(value))
+ )
def format_savepoint(self, savepoint):
- name = savepoint.ident.lstrip('_')
- return super(
- OracleIdentifierPreparer, self).format_savepoint(savepoint, name)
+ name = savepoint.ident.lstrip("_")
+ return super(OracleIdentifierPreparer, self).format_savepoint(
+ savepoint, name
+ )
class OracleExecutionContext(default.DefaultExecutionContext):
def fire_sequence(self, seq, type_):
return self._execute_scalar(
- "SELECT " +
- self.dialect.identifier_preparer.format_sequence(seq) +
- ".nextval FROM DUAL", type_)
+ "SELECT "
+ + self.dialect.identifier_preparer.format_sequence(seq)
+ + ".nextval FROM DUAL",
+ type_,
+ )
class OracleDialect(default.DefaultDialect):
- name = 'oracle'
+ name = "oracle"
supports_alter = True
supports_unicode_statements = False
supports_unicode_binds = False
@@ -1039,7 +1082,7 @@ class OracleDialect(default.DefaultDialect):
sequences_optional = False
postfetch_lastrowid = False
- default_paramstyle = 'named'
+ default_paramstyle = "named"
colspecs = colspecs
ischema_names = ischema_names
requires_name_normalize = True
@@ -1054,29 +1097,27 @@ class OracleDialect(default.DefaultDialect):
preparer = OracleIdentifierPreparer
execution_ctx_cls = OracleExecutionContext
- reflection_options = ('oracle_resolve_synonyms', )
+ reflection_options = ("oracle_resolve_synonyms",)
_use_nchar_for_unicode = False
construct_arguments = [
- (sa_schema.Table, {
- "resolve_synonyms": False,
- "on_commit": None,
- "compress": False
- }),
- (sa_schema.Index, {
- "bitmap": False,
- "compress": False
- })
+ (
+ sa_schema.Table,
+ {"resolve_synonyms": False, "on_commit": None, "compress": False},
+ ),
+ (sa_schema.Index, {"bitmap": False, "compress": False}),
]
- def __init__(self,
- use_ansi=True,
- optimize_limits=False,
- use_binds_for_limits=True,
- use_nchar_for_unicode=False,
- exclude_tablespaces=('SYSTEM', 'SYSAUX', ),
- **kwargs):
+ def __init__(
+ self,
+ use_ansi=True,
+ optimize_limits=False,
+ use_binds_for_limits=True,
+ use_nchar_for_unicode=False,
+ exclude_tablespaces=("SYSTEM", "SYSAUX"),
+ **kwargs
+ ):
default.DefaultDialect.__init__(self, **kwargs)
self._use_nchar_for_unicode = use_nchar_for_unicode
self.use_ansi = use_ansi
@@ -1087,8 +1128,7 @@ class OracleDialect(default.DefaultDialect):
def initialize(self, connection):
super(OracleDialect, self).initialize(connection)
self.implicit_returning = self.__dict__.get(
- 'implicit_returning',
- self.server_version_info > (10, )
+ "implicit_returning", self.server_version_info > (10,)
)
if self._is_oracle_8:
@@ -1098,18 +1138,15 @@ class OracleDialect(default.DefaultDialect):
@property
def _is_oracle_8(self):
- return self.server_version_info and \
- self.server_version_info < (9, )
+ return self.server_version_info and self.server_version_info < (9,)
@property
def _supports_table_compression(self):
- return self.server_version_info and \
- self.server_version_info >= (10, 1, )
+ return self.server_version_info and self.server_version_info >= (10, 1)
@property
def _supports_table_compress_for(self):
- return self.server_version_info and \
- self.server_version_info >= (11, )
+ return self.server_version_info and self.server_version_info >= (11,)
@property
def _supports_char_length(self):
@@ -1123,31 +1160,38 @@ class OracleDialect(default.DefaultDialect):
additional_tests = [
expression.cast(
expression.literal_column("'test nvarchar2 returns'"),
- sqltypes.NVARCHAR(60)
- ),
+ sqltypes.NVARCHAR(60),
+ )
]
return super(OracleDialect, self)._check_unicode_returns(
- connection, additional_tests)
+ connection, additional_tests
+ )
def has_table(self, connection, table_name, schema=None):
if not schema:
schema = self.default_schema_name
cursor = connection.execute(
- sql.text("SELECT table_name FROM all_tables "
- "WHERE table_name = :name AND owner = :schema_name"),
+ sql.text(
+ "SELECT table_name FROM all_tables "
+ "WHERE table_name = :name AND owner = :schema_name"
+ ),
name=self.denormalize_name(table_name),
- schema_name=self.denormalize_name(schema))
+ schema_name=self.denormalize_name(schema),
+ )
return cursor.first() is not None
def has_sequence(self, connection, sequence_name, schema=None):
if not schema:
schema = self.default_schema_name
cursor = connection.execute(
- sql.text("SELECT sequence_name FROM all_sequences "
- "WHERE sequence_name = :name AND "
- "sequence_owner = :schema_name"),
+ sql.text(
+ "SELECT sequence_name FROM all_sequences "
+ "WHERE sequence_name = :name AND "
+ "sequence_owner = :schema_name"
+ ),
name=self.denormalize_name(sequence_name),
- schema_name=self.denormalize_name(schema))
+ schema_name=self.denormalize_name(schema),
+ )
return cursor.first() is not None
def normalize_name(self, name):
@@ -1156,8 +1200,9 @@ class OracleDialect(default.DefaultDialect):
if util.py2k:
if isinstance(name, str):
name = name.decode(self.encoding)
- if name.upper() == name and not \
- self.identifier_preparer._requires_quotes(name.lower()):
+ if name.upper() == name and not self.identifier_preparer._requires_quotes(
+ name.lower()
+ ):
return name.lower()
elif name.lower() == name:
return quoted_name(name, quote=True)
@@ -1167,8 +1212,9 @@ class OracleDialect(default.DefaultDialect):
def denormalize_name(self, name):
if name is None:
return None
- elif name.lower() == name and not \
- self.identifier_preparer._requires_quotes(name.lower()):
+ elif name.lower() == name and not self.identifier_preparer._requires_quotes(
+ name.lower()
+ ):
name = name.upper()
if util.py2k:
if not self.supports_unicode_binds:
@@ -1179,10 +1225,16 @@ class OracleDialect(default.DefaultDialect):
def _get_default_schema_name(self, connection):
return self.normalize_name(
- connection.execute('SELECT USER FROM DUAL').scalar())
+ connection.execute("SELECT USER FROM DUAL").scalar()
+ )
- def _resolve_synonym(self, connection, desired_owner=None,
- desired_synonym=None, desired_table=None):
+ def _resolve_synonym(
+ self,
+ connection,
+ desired_owner=None,
+ desired_synonym=None,
+ desired_table=None,
+ ):
"""search for a local synonym matching the given desired owner/name.
if desired_owner is None, attempts to locate a distinct owner.
@@ -1191,19 +1243,21 @@ class OracleDialect(default.DefaultDialect):
found.
"""
- q = "SELECT owner, table_owner, table_name, db_link, "\
+ q = (
+ "SELECT owner, table_owner, table_name, db_link, "
"synonym_name FROM all_synonyms WHERE "
+ )
clauses = []
params = {}
if desired_synonym:
clauses.append("synonym_name = :synonym_name")
- params['synonym_name'] = desired_synonym
+ params["synonym_name"] = desired_synonym
if desired_owner:
clauses.append("owner = :desired_owner")
- params['desired_owner'] = desired_owner
+ params["desired_owner"] = desired_owner
if desired_table:
clauses.append("table_name = :tname")
- params['tname'] = desired_table
+ params["tname"] = desired_table
q += " AND ".join(clauses)
@@ -1211,8 +1265,12 @@ class OracleDialect(default.DefaultDialect):
if desired_owner:
row = result.first()
if row:
- return (row['table_name'], row['table_owner'],
- row['db_link'], row['synonym_name'])
+ return (
+ row["table_name"],
+ row["table_owner"],
+ row["db_link"],
+ row["synonym_name"],
+ )
else:
return None, None, None, None
else:
@@ -1220,23 +1278,35 @@ class OracleDialect(default.DefaultDialect):
if len(rows) > 1:
raise AssertionError(
"There are multiple tables visible to the schema, you "
- "must specify owner")
+ "must specify owner"
+ )
elif len(rows) == 1:
row = rows[0]
- return (row['table_name'], row['table_owner'],
- row['db_link'], row['synonym_name'])
+ return (
+ row["table_name"],
+ row["table_owner"],
+ row["db_link"],
+ row["synonym_name"],
+ )
else:
return None, None, None, None
@reflection.cache
- def _prepare_reflection_args(self, connection, table_name, schema=None,
- resolve_synonyms=False, dblink='', **kw):
+ def _prepare_reflection_args(
+ self,
+ connection,
+ table_name,
+ schema=None,
+ resolve_synonyms=False,
+ dblink="",
+ **kw
+ ):
if resolve_synonyms:
actual_name, owner, dblink, synonym = self._resolve_synonym(
connection,
desired_owner=self.denormalize_name(schema),
- desired_synonym=self.denormalize_name(table_name)
+ desired_synonym=self.denormalize_name(table_name),
)
else:
actual_name, owner, dblink, synonym = None, None, None, None
@@ -1250,18 +1320,21 @@ class OracleDialect(default.DefaultDialect):
# will need to hear from more users if we are doing
# the right thing here. See [ticket:2619]
owner = connection.scalar(
- sql.text("SELECT username FROM user_db_links "
- "WHERE db_link=:link"), link=dblink)
+ sql.text(
+ "SELECT username FROM user_db_links " "WHERE db_link=:link"
+ ),
+ link=dblink,
+ )
dblink = "@" + dblink
elif not owner:
owner = self.denormalize_name(schema or self.default_schema_name)
- return (actual_name, owner, dblink or '', synonym)
+ return (actual_name, owner, dblink or "", synonym)
@reflection.cache
def get_schema_names(self, connection, **kw):
s = "SELECT username FROM all_users ORDER BY username"
- cursor = connection.execute(s,)
+ cursor = connection.execute(s)
return [self.normalize_name(row[0]) for row in cursor]
@reflection.cache
@@ -1276,14 +1349,12 @@ class OracleDialect(default.DefaultDialect):
if self.exclude_tablespaces:
sql_str += (
"nvl(tablespace_name, 'no tablespace') "
- "NOT IN (%s) AND " % (
- ', '.join(["'%s'" % ts for ts in self.exclude_tablespaces])
- )
+ "NOT IN (%s) AND "
+ % (", ".join(["'%s'" % ts for ts in self.exclude_tablespaces]))
)
sql_str += (
- "OWNER = :owner "
- "AND IOT_NAME IS NULL "
- "AND DURATION IS NULL")
+ "OWNER = :owner " "AND IOT_NAME IS NULL " "AND DURATION IS NULL"
+ )
cursor = connection.execute(sql.text(sql_str), owner=schema)
return [self.normalize_name(row[0]) for row in cursor]
@@ -1296,14 +1367,14 @@ class OracleDialect(default.DefaultDialect):
if self.exclude_tablespaces:
sql_str += (
"nvl(tablespace_name, 'no tablespace') "
- "NOT IN (%s) AND " % (
- ', '.join(["'%s'" % ts for ts in self.exclude_tablespaces])
- )
+ "NOT IN (%s) AND "
+ % (", ".join(["'%s'" % ts for ts in self.exclude_tablespaces]))
)
sql_str += (
"OWNER = :owner "
"AND IOT_NAME IS NULL "
- "AND DURATION IS NOT NULL")
+ "AND DURATION IS NOT NULL"
+ )
cursor = connection.execute(sql.text(sql_str), owner=schema)
return [self.normalize_name(row[0]) for row in cursor]
@@ -1319,14 +1390,18 @@ class OracleDialect(default.DefaultDialect):
def get_table_options(self, connection, table_name, schema=None, **kw):
options = {}
- resolve_synonyms = kw.get('oracle_resolve_synonyms', False)
- dblink = kw.get('dblink', '')
- info_cache = kw.get('info_cache')
-
- (table_name, schema, dblink, synonym) = \
- self._prepare_reflection_args(connection, table_name, schema,
- resolve_synonyms, dblink,
- info_cache=info_cache)
+ resolve_synonyms = kw.get("oracle_resolve_synonyms", False)
+ dblink = kw.get("dblink", "")
+ info_cache = kw.get("info_cache")
+
+ (table_name, schema, dblink, synonym) = self._prepare_reflection_args(
+ connection,
+ table_name,
+ schema,
+ resolve_synonyms,
+ dblink,
+ info_cache=info_cache,
+ )
params = {"table_name": table_name}
@@ -1336,14 +1411,16 @@ class OracleDialect(default.DefaultDialect):
if self._supports_table_compress_for:
columns.append("compress_for")
- text = "SELECT %(columns)s "\
- "FROM ALL_TABLES%(dblink)s "\
+ text = (
+ "SELECT %(columns)s "
+ "FROM ALL_TABLES%(dblink)s "
"WHERE table_name = :table_name"
+ )
if schema is not None:
- params['owner'] = schema
+ params["owner"] = schema
text += " AND owner = :owner "
- text = text % {'dblink': dblink, 'columns': ", ".join(columns)}
+ text = text % {"dblink": dblink, "columns": ", ".join(columns)}
result = connection.execute(sql.text(text), **params)
@@ -1353,9 +1430,9 @@ class OracleDialect(default.DefaultDialect):
if row:
if "compression" in row and enabled.get(row.compression, False):
if "compress_for" in row:
- options['oracle_compress'] = row.compress_for
+ options["oracle_compress"] = row.compress_for
else:
- options['oracle_compress'] = True
+ options["oracle_compress"] = True
return options
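# e.g. a table created with COMPRESS FOR OLTP would reflect roughly as
#   {"oracle_compress": "OLTP"}
# while an uncompressed table yields an empty options dict.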
@@ -1371,19 +1448,23 @@ class OracleDialect(default.DefaultDialect):
"""
- resolve_synonyms = kw.get('oracle_resolve_synonyms', False)
- dblink = kw.get('dblink', '')
- info_cache = kw.get('info_cache')
-
- (table_name, schema, dblink, synonym) = \
- self._prepare_reflection_args(connection, table_name, schema,
- resolve_synonyms, dblink,
- info_cache=info_cache)
+ resolve_synonyms = kw.get("oracle_resolve_synonyms", False)
+ dblink = kw.get("dblink", "")
+ info_cache = kw.get("info_cache")
+
+ (table_name, schema, dblink, synonym) = self._prepare_reflection_args(
+ connection,
+ table_name,
+ schema,
+ resolve_synonyms,
+ dblink,
+ info_cache=info_cache,
+ )
columns = []
if self._supports_char_length:
- char_length_col = 'char_length'
+ char_length_col = "char_length"
else:
- char_length_col = 'data_length'
+ char_length_col = "data_length"
params = {"table_name": table_name}
text = """
@@ -1398,10 +1479,10 @@ class OracleDialect(default.DefaultDialect):
WHERE col.table_name = :table_name
"""
if schema is not None:
- params['owner'] = schema
+ params["owner"] = schema
text += " AND col.owner = :owner "
text += " ORDER BY col.column_id"
- text = text % {'dblink': dblink, 'char_length_col': char_length_col}
+ text = text % {"dblink": dblink, "char_length_col": char_length_col}
c = connection.execute(sql.text(text), **params)
@@ -1412,54 +1493,67 @@ class OracleDialect(default.DefaultDialect):
length = row[2]
precision = row[3]
scale = row[4]
- nullable = row[5] == 'Y'
+ nullable = row[5] == "Y"
default = row[6]
comment = row[7]
- if coltype == 'NUMBER':
+ if coltype == "NUMBER":
if precision is None and scale == 0:
coltype = INTEGER()
else:
coltype = NUMBER(precision, scale)
- elif coltype == 'FLOAT':
+ elif coltype == "FLOAT":
# TODO: support "precision" here as "binary_precision"
coltype = FLOAT()
- elif coltype in ('VARCHAR2', 'NVARCHAR2', 'CHAR'):
+ elif coltype in ("VARCHAR2", "NVARCHAR2", "CHAR"):
coltype = self.ischema_names.get(coltype)(length)
- elif 'WITH TIME ZONE' in coltype:
+ elif "WITH TIME ZONE" in coltype:
coltype = TIMESTAMP(timezone=True)
else:
- coltype = re.sub(r'\(\d+\)', '', coltype)
+ coltype = re.sub(r"\(\d+\)", "", coltype)
try:
coltype = self.ischema_names[coltype]
except KeyError:
- util.warn("Did not recognize type '%s' of column '%s'" %
- (coltype, colname))
+ util.warn(
+ "Did not recognize type '%s' of column '%s'"
+ % (coltype, colname)
+ )
coltype = sqltypes.NULLTYPE
cdict = {
- 'name': colname,
- 'type': coltype,
- 'nullable': nullable,
- 'default': default,
- 'autoincrement': 'auto',
- 'comment': comment,
+ "name": colname,
+ "type": coltype,
+ "nullable": nullable,
+ "default": default,
+ "autoincrement": "auto",
+ "comment": comment,
}
if orig_colname.lower() == orig_colname:
- cdict['quote'] = True
+ cdict["quote"] = True
columns.append(cdict)
return columns
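# Shape of one element of the returned list, assuming a nullable
# VARCHAR2(50) column named FULL_NAME:
#   {"name": "full_name", "type": VARCHAR(50), "nullable": True,
#    "default": None, "autoincrement": "auto", "comment": None}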
@reflection.cache
- def get_table_comment(self, connection, table_name, schema=None,
- resolve_synonyms=False, dblink='', **kw):
-
- info_cache = kw.get('info_cache')
- (table_name, schema, dblink, synonym) = \
- self._prepare_reflection_args(connection, table_name, schema,
- resolve_synonyms, dblink,
- info_cache=info_cache)
+ def get_table_comment(
+ self,
+ connection,
+ table_name,
+ schema=None,
+ resolve_synonyms=False,
+ dblink="",
+ **kw
+ ):
+
+ info_cache = kw.get("info_cache")
+ (table_name, schema, dblink, synonym) = self._prepare_reflection_args(
+ connection,
+ table_name,
+ schema,
+ resolve_synonyms,
+ dblink,
+ info_cache=info_cache,
+ )
COMMENT_SQL = """
SELECT comments
@@ -1471,67 +1565,90 @@ class OracleDialect(default.DefaultDialect):
return {"text": c.scalar()}
@reflection.cache
- def get_indexes(self, connection, table_name, schema=None,
- resolve_synonyms=False, dblink='', **kw):
-
- info_cache = kw.get('info_cache')
- (table_name, schema, dblink, synonym) = \
- self._prepare_reflection_args(connection, table_name, schema,
- resolve_synonyms, dblink,
- info_cache=info_cache)
+ def get_indexes(
+ self,
+ connection,
+ table_name,
+ schema=None,
+ resolve_synonyms=False,
+ dblink="",
+ **kw
+ ):
+
+ info_cache = kw.get("info_cache")
+ (table_name, schema, dblink, synonym) = self._prepare_reflection_args(
+ connection,
+ table_name,
+ schema,
+ resolve_synonyms,
+ dblink,
+ info_cache=info_cache,
+ )
indexes = []
- params = {'table_name': table_name}
- text = \
- "SELECT a.index_name, a.column_name, "\
- "\nb.index_type, b.uniqueness, b.compression, b.prefix_length "\
- "\nFROM ALL_IND_COLUMNS%(dblink)s a, "\
- "\nALL_INDEXES%(dblink)s b "\
- "\nWHERE "\
- "\na.index_name = b.index_name "\
- "\nAND a.table_owner = b.table_owner "\
- "\nAND a.table_name = b.table_name "\
+ params = {"table_name": table_name}
+ text = (
+ "SELECT a.index_name, a.column_name, "
+ "\nb.index_type, b.uniqueness, b.compression, b.prefix_length "
+ "\nFROM ALL_IND_COLUMNS%(dblink)s a, "
+ "\nALL_INDEXES%(dblink)s b "
+ "\nWHERE "
+ "\na.index_name = b.index_name "
+ "\nAND a.table_owner = b.table_owner "
+ "\nAND a.table_name = b.table_name "
"\nAND a.table_name = :table_name "
+ )
if schema is not None:
- params['schema'] = schema
+ params["schema"] = schema
text += "AND a.table_owner = :schema "
text += "ORDER BY a.index_name, a.column_position"
- text = text % {'dblink': dblink}
+ text = text % {"dblink": dblink}
q = sql.text(text)
rp = connection.execute(q, **params)
indexes = []
last_index_name = None
pk_constraint = self.get_pk_constraint(
- connection, table_name, schema, resolve_synonyms=resolve_synonyms,
- dblink=dblink, info_cache=kw.get('info_cache'))
- pkeys = pk_constraint['constrained_columns']
+ connection,
+ table_name,
+ schema,
+ resolve_synonyms=resolve_synonyms,
+ dblink=dblink,
+ info_cache=kw.get("info_cache"),
+ )
+ pkeys = pk_constraint["constrained_columns"]
uniqueness = dict(NONUNIQUE=False, UNIQUE=True)
enabled = dict(DISABLED=False, ENABLED=True)
- oracle_sys_col = re.compile(r'SYS_NC\d+\$', re.IGNORECASE)
+ oracle_sys_col = re.compile(r"SYS_NC\d+\$", re.IGNORECASE)
index = None
for rset in rp:
if rset.index_name != last_index_name:
- index = dict(name=self.normalize_name(rset.index_name),
- column_names=[], dialect_options={})
+ index = dict(
+ name=self.normalize_name(rset.index_name),
+ column_names=[],
+ dialect_options={},
+ )
indexes.append(index)
- index['unique'] = uniqueness.get(rset.uniqueness, False)
+ index["unique"] = uniqueness.get(rset.uniqueness, False)
- if rset.index_type in ('BITMAP', 'FUNCTION-BASED BITMAP'):
- index['dialect_options']['oracle_bitmap'] = True
+ if rset.index_type in ("BITMAP", "FUNCTION-BASED BITMAP"):
+ index["dialect_options"]["oracle_bitmap"] = True
if enabled.get(rset.compression, False):
- index['dialect_options']['oracle_compress'] = rset.prefix_length
+ index["dialect_options"][
+ "oracle_compress"
+ ] = rset.prefix_length
# Filter out Oracle SYS_NC names; we could also do an outer join
# to the all_tab_columns table and check for real col names there.
if not oracle_sys_col.match(rset.column_name):
- index['column_names'].append(
- self.normalize_name(rset.column_name))
+ index["column_names"].append(
+ self.normalize_name(rset.column_name)
+ )
last_index_name = rset.index_name
def upper_name_set(names):
@@ -1539,18 +1656,21 @@ class OracleDialect(default.DefaultDialect):
pk_names = upper_name_set(pkeys)
if pk_names:
+
def is_pk_index(index):
# don't include the primary key index
- return upper_name_set(index['column_names']) == pk_names
+ return upper_name_set(index["column_names"]) == pk_names
+
indexes = [idx for idx in indexes if not is_pk_index(idx)]
return indexes
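# Rough shape of one reflected (non-primary-key) index entry:
#   {"name": "ix_email", "column_names": ["email"], "unique": True,
#    "dialect_options": {}}
# Bitmap and compressed indexes additionally carry "oracle_bitmap" /
# "oracle_compress" keys in dialect_options, per the branches above.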
@reflection.cache
- def _get_constraint_data(self, connection, table_name, schema=None,
- dblink='', **kw):
+ def _get_constraint_data(
+ self, connection, table_name, schema=None, dblink="", **kw
+ ):
- params = {'table_name': table_name}
+ params = {"table_name": table_name}
text = (
"SELECT"
@@ -1572,7 +1692,7 @@ class OracleDialect(default.DefaultDialect):
)
if schema is not None:
- params['owner'] = schema
+ params["owner"] = schema
text += "\nAND ac.owner = :owner"
text += (
@@ -1584,35 +1704,49 @@ class OracleDialect(default.DefaultDialect):
"\nORDER BY ac.constraint_name, loc.position"
)
- text = text % {'dblink': dblink}
+ text = text % {"dblink": dblink}
rp = connection.execute(sql.text(text), **params)
constraint_data = rp.fetchall()
return constraint_data
@reflection.cache
def get_pk_constraint(self, connection, table_name, schema=None, **kw):
- resolve_synonyms = kw.get('oracle_resolve_synonyms', False)
- dblink = kw.get('dblink', '')
- info_cache = kw.get('info_cache')
-
- (table_name, schema, dblink, synonym) = \
- self._prepare_reflection_args(connection, table_name, schema,
- resolve_synonyms, dblink,
- info_cache=info_cache)
+ resolve_synonyms = kw.get("oracle_resolve_synonyms", False)
+ dblink = kw.get("dblink", "")
+ info_cache = kw.get("info_cache")
+
+ (table_name, schema, dblink, synonym) = self._prepare_reflection_args(
+ connection,
+ table_name,
+ schema,
+ resolve_synonyms,
+ dblink,
+ info_cache=info_cache,
+ )
pkeys = []
constraint_name = None
constraint_data = self._get_constraint_data(
- connection, table_name, schema, dblink,
- info_cache=kw.get('info_cache'))
+ connection,
+ table_name,
+ schema,
+ dblink,
+ info_cache=kw.get("info_cache"),
+ )
for row in constraint_data:
- (cons_name, cons_type, local_column, remote_table, remote_column, remote_owner) = \
- row[0:2] + tuple([self.normalize_name(x) for x in row[2:6]])
- if cons_type == 'P':
+ (
+ cons_name,
+ cons_type,
+ local_column,
+ remote_table,
+ remote_column,
+ remote_owner,
+ ) = row[0:2] + tuple([self.normalize_name(x) for x in row[2:6]])
+ if cons_type == "P":
if constraint_name is None:
constraint_name = self.normalize_name(cons_name)
pkeys.append(local_column)
- return {'constrained_columns': pkeys, 'name': constraint_name}
+ return {"constrained_columns": pkeys, "name": constraint_name}
@reflection.cache
def get_foreign_keys(self, connection, table_name, schema=None, **kw):
@@ -1626,74 +1760,94 @@ class OracleDialect(default.DefaultDialect):
"""
requested_schema = schema # to check later on
- resolve_synonyms = kw.get('oracle_resolve_synonyms', False)
- dblink = kw.get('dblink', '')
- info_cache = kw.get('info_cache')
-
- (table_name, schema, dblink, synonym) = \
- self._prepare_reflection_args(connection, table_name, schema,
- resolve_synonyms, dblink,
- info_cache=info_cache)
+ resolve_synonyms = kw.get("oracle_resolve_synonyms", False)
+ dblink = kw.get("dblink", "")
+ info_cache = kw.get("info_cache")
+
+ (table_name, schema, dblink, synonym) = self._prepare_reflection_args(
+ connection,
+ table_name,
+ schema,
+ resolve_synonyms,
+ dblink,
+ info_cache=info_cache,
+ )
constraint_data = self._get_constraint_data(
- connection, table_name, schema, dblink,
- info_cache=kw.get('info_cache'))
+ connection,
+ table_name,
+ schema,
+ dblink,
+ info_cache=kw.get("info_cache"),
+ )
def fkey_rec():
return {
- 'name': None,
- 'constrained_columns': [],
- 'referred_schema': None,
- 'referred_table': None,
- 'referred_columns': [],
- 'options': {},
+ "name": None,
+ "constrained_columns": [],
+ "referred_schema": None,
+ "referred_table": None,
+ "referred_columns": [],
+ "options": {},
}
fkeys = util.defaultdict(fkey_rec)
for row in constraint_data:
- (cons_name, cons_type, local_column, remote_table, remote_column, remote_owner) = \
- row[0:2] + tuple([self.normalize_name(x) for x in row[2:6]])
+ (
+ cons_name,
+ cons_type,
+ local_column,
+ remote_table,
+ remote_column,
+ remote_owner,
+ ) = row[0:2] + tuple([self.normalize_name(x) for x in row[2:6]])
cons_name = self.normalize_name(cons_name)
- if cons_type == 'R':
+ if cons_type == "R":
if remote_table is None:
# ticket 363
util.warn(
- ("Got 'None' querying 'table_name' from "
- "all_cons_columns%(dblink)s - does the user have "
- "proper rights to the table?") % {'dblink': dblink})
+ (
+ "Got 'None' querying 'table_name' from "
+ "all_cons_columns%(dblink)s - does the user have "
+ "proper rights to the table?"
+ )
+ % {"dblink": dblink}
+ )
continue
rec = fkeys[cons_name]
- rec['name'] = cons_name
- local_cols, remote_cols = rec[
- 'constrained_columns'], rec['referred_columns']
+ rec["name"] = cons_name
+ local_cols, remote_cols = (
+ rec["constrained_columns"],
+ rec["referred_columns"],
+ )
- if not rec['referred_table']:
+ if not rec["referred_table"]:
if resolve_synonyms:
- ref_remote_name, ref_remote_owner, ref_dblink, ref_synonym = \
- self._resolve_synonym(
- connection,
- desired_owner=self.denormalize_name(
- remote_owner),
- desired_table=self.denormalize_name(
- remote_table)
- )
+ ref_remote_name, ref_remote_owner, ref_dblink, ref_synonym = self._resolve_synonym(
+ connection,
+ desired_owner=self.denormalize_name(remote_owner),
+ desired_table=self.denormalize_name(remote_table),
+ )
if ref_synonym:
remote_table = self.normalize_name(ref_synonym)
remote_owner = self.normalize_name(
- ref_remote_owner)
+ ref_remote_owner
+ )
- rec['referred_table'] = remote_table
+ rec["referred_table"] = remote_table
- if requested_schema is not None or \
- self.denormalize_name(remote_owner) != schema:
- rec['referred_schema'] = remote_owner
+ if (
+ requested_schema is not None
+ or self.denormalize_name(remote_owner) != schema
+ ):
+ rec["referred_schema"] = remote_owner
- if row[9] != 'NO ACTION':
- rec['options']['ondelete'] = row[9]
+ if row[9] != "NO ACTION":
+ rec["options"]["ondelete"] = row[9]
local_cols.append(local_column)
remote_cols.append(remote_column)
@@ -1701,54 +1855,82 @@ class OracleDialect(default.DefaultDialect):
return list(fkeys.values())
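# Sketch of one reflected foreign key record (hypothetical names):
#   {"name": "fk_order_customer",
#    "constrained_columns": ["customer_id"],
#    "referred_schema": None, "referred_table": "customer",
#    "referred_columns": ["id"], "options": {"ondelete": "CASCADE"}}
# referred_schema stays None only when no schema was explicitly requested
# and the referred owner matches the reflected schema.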
@reflection.cache
- def get_unique_constraints(self, connection, table_name, schema=None, **kw):
- resolve_synonyms = kw.get('oracle_resolve_synonyms', False)
- dblink = kw.get('dblink', '')
- info_cache = kw.get('info_cache')
-
- (table_name, schema, dblink, synonym) = \
- self._prepare_reflection_args(connection, table_name, schema,
- resolve_synonyms, dblink,
- info_cache=info_cache)
+ def get_unique_constraints(
+ self, connection, table_name, schema=None, **kw
+ ):
+ resolve_synonyms = kw.get("oracle_resolve_synonyms", False)
+ dblink = kw.get("dblink", "")
+ info_cache = kw.get("info_cache")
+
+ (table_name, schema, dblink, synonym) = self._prepare_reflection_args(
+ connection,
+ table_name,
+ schema,
+ resolve_synonyms,
+ dblink,
+ info_cache=info_cache,
+ )
constraint_data = self._get_constraint_data(
- connection, table_name, schema, dblink,
- info_cache=kw.get('info_cache'))
+ connection,
+ table_name,
+ schema,
+ dblink,
+ info_cache=kw.get("info_cache"),
+ )
- unique_keys = filter(lambda x: x[1] == 'U', constraint_data)
+ unique_keys = filter(lambda x: x[1] == "U", constraint_data)
uniques_group = groupby(unique_keys, lambda x: x[0])
- index_names = set([ix['name'] for ix in self.get_indexes(connection, table_name, schema=schema)])
+ index_names = set(
+ [
+ ix["name"]
+ for ix in self.get_indexes(
+ connection, table_name, schema=schema
+ )
+ ]
+ )
return [
{
- 'name': name,
- 'column_names': cols,
- 'duplicates_index': name if name in index_names else None
+ "name": name,
+ "column_names": cols,
+ "duplicates_index": name if name in index_names else None,
}
- for name, cols in
- [
+ for name, cols in [
[
self.normalize_name(i[0]),
- [self.normalize_name(x[2]) for x in i[1]]
- ] for i in uniques_group
+ [self.normalize_name(x[2]) for x in i[1]],
+ ]
+ for i in uniques_group
]
]
@reflection.cache
- def get_view_definition(self, connection, view_name, schema=None,
- resolve_synonyms=False, dblink='', **kw):
- info_cache = kw.get('info_cache')
- (view_name, schema, dblink, synonym) = \
- self._prepare_reflection_args(connection, view_name, schema,
- resolve_synonyms, dblink,
- info_cache=info_cache)
-
- params = {'view_name': view_name}
+ def get_view_definition(
+ self,
+ connection,
+ view_name,
+ schema=None,
+ resolve_synonyms=False,
+ dblink="",
+ **kw
+ ):
+ info_cache = kw.get("info_cache")
+ (view_name, schema, dblink, synonym) = self._prepare_reflection_args(
+ connection,
+ view_name,
+ schema,
+ resolve_synonyms,
+ dblink,
+ info_cache=info_cache,
+ )
+
+ params = {"view_name": view_name}
text = "SELECT text FROM all_views WHERE view_name=:view_name"
if schema is not None:
text += " AND owner = :schema"
- params['schema'] = schema
+ params["schema"] = schema
rp = connection.execute(sql.text(text), **params).scalar()
if rp:
@@ -1759,34 +1941,41 @@ class OracleDialect(default.DefaultDialect):
return None
@reflection.cache
- def get_check_constraints(self, connection, table_name, schema=None,
- include_all=False, **kw):
- resolve_synonyms = kw.get('oracle_resolve_synonyms', False)
- dblink = kw.get('dblink', '')
- info_cache = kw.get('info_cache')
-
- (table_name, schema, dblink, synonym) = \
- self._prepare_reflection_args(connection, table_name, schema,
- resolve_synonyms, dblink,
- info_cache=info_cache)
+ def get_check_constraints(
+ self, connection, table_name, schema=None, include_all=False, **kw
+ ):
+ resolve_synonyms = kw.get("oracle_resolve_synonyms", False)
+ dblink = kw.get("dblink", "")
+ info_cache = kw.get("info_cache")
+
+ (table_name, schema, dblink, synonym) = self._prepare_reflection_args(
+ connection,
+ table_name,
+ schema,
+ resolve_synonyms,
+ dblink,
+ info_cache=info_cache,
+ )
constraint_data = self._get_constraint_data(
- connection, table_name, schema, dblink,
- info_cache=kw.get('info_cache'))
+ connection,
+ table_name,
+ schema,
+ dblink,
+ info_cache=kw.get("info_cache"),
+ )
- check_constraints = filter(lambda x: x[1] == 'C', constraint_data)
+ check_constraints = filter(lambda x: x[1] == "C", constraint_data)
return [
- {
- 'name': self.normalize_name(cons[0]),
- 'sqltext': cons[8],
- }
- for cons in check_constraints if include_all or
- not re.match(r'..+?. IS NOT NULL$', cons[8])]
+ {"name": self.normalize_name(cons[0]), "sqltext": cons[8]}
+ for cons in check_constraints
+ if include_all or not re.match(r"..+?. IS NOT NULL$", cons[8])
+ ]
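# e.g. a hypothetical CHECK (price > 0) constraint reflects roughly as
#   {"name": "ck_price_positive", "sqltext": "price > 0"};
# implicit '"COL" IS NOT NULL' checks are filtered out by the regex above
# unless include_all=True is passed.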
class _OuterJoinColumn(sql.ClauseElement):
- __visit_name__ = 'outer_join_column'
+ __visit_name__ = "outer_join_column"
def __init__(self, column):
self.column = column
diff --git a/lib/sqlalchemy/dialects/oracle/cx_oracle.py b/lib/sqlalchemy/dialects/oracle/cx_oracle.py
index a00e7d95e..91534c0da 100644
--- a/lib/sqlalchemy/dialects/oracle/cx_oracle.py
+++ b/lib/sqlalchemy/dialects/oracle/cx_oracle.py
@@ -296,16 +296,13 @@ class _OracleInteger(sqltypes.Integer):
def _cx_oracle_var(self, dialect, cursor):
cx_Oracle = dialect.dbapi
return cursor.var(
- cx_Oracle.STRING,
- 255,
- arraysize=cursor.arraysize,
- outconverter=int
+ cx_Oracle.STRING, 255, arraysize=cursor.arraysize, outconverter=int
)
def _cx_oracle_outputtypehandler(self, dialect):
- def handler(cursor, name,
- default_type, size, precision, scale):
+ def handler(cursor, name, default_type, size, precision, scale):
return self._cx_oracle_var(dialect, cursor)
+
return handler
@@ -317,7 +314,8 @@ class _OracleNumeric(sqltypes.Numeric):
return None
elif self.asdecimal:
processor = processors.to_decimal_processor_factory(
- decimal.Decimal, self._effective_decimal_return_scale)
+ decimal.Decimal, self._effective_decimal_return_scale
+ )
def process(value):
if isinstance(value, (int, float)):
@@ -326,6 +324,7 @@ class _OracleNumeric(sqltypes.Numeric):
return float(value)
else:
return value
+
return process
else:
return processors.to_float
@@ -383,9 +382,10 @@ class _OracleNumeric(sqltypes.Numeric):
type_ = cx_Oracle.NATIVE_FLOAT
return cursor.var(
- type_, 255,
+ type_,
+ 255,
arraysize=cursor.arraysize,
- outconverter=outconverter
+ outconverter=outconverter,
)
return handler
@@ -418,6 +418,7 @@ class _OracleDate(sqltypes.Date):
return value.date()
else:
return value
+
return process
@@ -467,6 +468,7 @@ class _OracleEnum(sqltypes.Enum):
def process(value):
raw_str = enum_proc(value)
return raw_str
+
return process
@@ -482,7 +484,8 @@ class _OracleBinary(sqltypes.LargeBinary):
return None
else:
return super(_OracleBinary, self).result_processor(
- dialect, coltype)
+ dialect, coltype
+ )
class _OracleInterval(oracle.INTERVAL):
@@ -503,14 +506,18 @@ class OracleCompiler_cx_oracle(OracleCompiler):
_oracle_cx_sql_compiler = True
def bindparam_string(self, name, **kw):
- quote = getattr(name, 'quote', None)
- if quote is True or quote is not False and \
- self.preparer._bindparam_requires_quotes(name):
- if kw.get('expanding', False):
+ quote = getattr(name, "quote", None)
+ if (
+ quote is True
+ or quote is not False
+ and self.preparer._bindparam_requires_quotes(name)
+ ):
+ if kw.get("expanding", False):
raise exc.CompileError(
"Can't use expanding feature with parameter name "
"%r on Oracle; it requires quoting which is not supported "
- "in this context." % name)
+ "in this context." % name
+ )
quoted_name = '"%s"' % name
self._quoted_bind_names[name] = quoted_name
return OracleCompiler.bindparam_string(self, quoted_name, **kw)
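# e.g. a bind parameter named "user" (a reserved word requiring quoting)
# is rendered as :"user", and the mapping user -> "user" is stored in
# _quoted_bind_names so execution can remap the parameter dictionary.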
@@ -537,21 +544,22 @@ class OracleExecutionContext_cx_oracle(OracleExecutionContext):
if bindparam.isoutparam:
name = self.compiled.bind_names[bindparam]
type_impl = bindparam.type.dialect_impl(self.dialect)
- if hasattr(type_impl, '_cx_oracle_var'):
+ if hasattr(type_impl, "_cx_oracle_var"):
self.out_parameters[name] = type_impl._cx_oracle_var(
- self.dialect, self.cursor)
+ self.dialect, self.cursor
+ )
else:
dbtype = type_impl.get_dbapi_type(self.dialect.dbapi)
if dbtype is None:
raise exc.InvalidRequestError(
"Cannot create out parameter for parameter "
"%r - its type %r is not supported by"
- " cx_oracle" %
- (bindparam.key, bindparam.type)
+ " cx_oracle" % (bindparam.key, bindparam.type)
)
self.out_parameters[name] = self.cursor.var(dbtype)
- self.parameters[0][quoted_bind_names.get(name, name)] = \
- self.out_parameters[name]
+ self.parameters[0][
+ quoted_bind_names.get(name, name)
+ ] = self.out_parameters[name]
def _generate_cursor_outputtype_handler(self):
output_handlers = {}
@@ -559,8 +567,9 @@ class OracleExecutionContext_cx_oracle(OracleExecutionContext):
for (keyname, name, objects, type_) in self.compiled._result_columns:
handler = type_._cached_custom_processor(
self.dialect,
- 'cx_oracle_outputtypehandler',
- self._get_cx_oracle_type_handler)
+ "cx_oracle_outputtypehandler",
+ self._get_cx_oracle_type_handler,
+ )
if handler:
denormalized_name = self.dialect.denormalize_name(keyname)
@@ -569,16 +578,18 @@ class OracleExecutionContext_cx_oracle(OracleExecutionContext):
if output_handlers:
default_handler = self._dbapi_connection.outputtypehandler
- def output_type_handler(cursor, name, default_type,
- size, precision, scale):
+ def output_type_handler(
+ cursor, name, default_type, size, precision, scale
+ ):
if name in output_handlers:
return output_handlers[name](
- cursor, name,
- default_type, size, precision, scale)
+ cursor, name, default_type, size, precision, scale
+ )
else:
return default_handler(
cursor, name, default_type, size, precision, scale
)
+
self.cursor.outputtypehandler = output_type_handler
def _get_cx_oracle_type_handler(self, impl):
@@ -598,7 +609,7 @@ class OracleExecutionContext_cx_oracle(OracleExecutionContext):
self.set_input_sizes(
self.compiled._quoted_bind_names,
- include_types=self.dialect._include_setinputsizes
+ include_types=self.dialect._include_setinputsizes,
)
self._handle_out_parameters()
@@ -615,9 +626,7 @@ class OracleExecutionContext_cx_oracle(OracleExecutionContext):
def get_result_proxy(self):
if self.out_parameters and self.compiled.returning:
returning_params = [
- self.dialect._returningval(
- self.out_parameters["ret_%d" % i]
- )
+ self.dialect._returningval(self.out_parameters["ret_%d" % i])
for i in range(len(self.out_parameters))
]
return ReturningResultProxy(self, returning_params)
@@ -625,8 +634,10 @@ class OracleExecutionContext_cx_oracle(OracleExecutionContext):
result = _result.ResultProxy(self)
if self.out_parameters:
- if self.compiled_parameters is not None and \
- len(self.compiled_parameters) == 1:
+ if (
+ self.compiled_parameters is not None
+ and len(self.compiled_parameters) == 1
+ ):
result.out_parameters = out_parameters = {}
for bind, name in self.compiled.bind_names.items():
@@ -634,22 +645,24 @@ class OracleExecutionContext_cx_oracle(OracleExecutionContext):
type = bind.type
impl_type = type.dialect_impl(self.dialect)
dbapi_type = impl_type.get_dbapi_type(
- self.dialect.dbapi)
- result_processor = impl_type.\
- result_processor(self.dialect,
- dbapi_type)
+ self.dialect.dbapi
+ )
+ result_processor = impl_type.result_processor(
+ self.dialect, dbapi_type
+ )
if result_processor is not None:
- out_parameters[name] = \
- result_processor(
- self.dialect._paramval(
- self.out_parameters[name]
- ))
+ out_parameters[name] = result_processor(
+ self.dialect._paramval(
+ self.out_parameters[name]
+ )
+ )
else:
out_parameters[name] = self.dialect._paramval(
- self.out_parameters[name])
+ self.out_parameters[name]
+ )
else:
result.out_parameters = dict(
- (k, self._dialect._paramval(v))
+ (k, self.dialect._paramval(v))
for k, v in self.out_parameters.items()
)
@@ -667,14 +680,11 @@ class ReturningResultProxy(_result.FullyBufferedResultProxy):
def _cursor_description(self):
returning = self.context.compiled.returning
return [
- (getattr(col, 'name', col.anon_label), None)
- for col in returning
+ (getattr(col, "name", col.anon_label), None) for col in returning
]
def _buffer_rows(self):
- return collections.deque(
- [tuple(self._returning_params)]
- )
+ return collections.deque([tuple(self._returning_params)])
class OracleDialect_cx_oracle(OracleDialect):
@@ -696,7 +706,6 @@ class OracleDialect_cx_oracle(OracleDialect):
oracle.BINARY_DOUBLE: _OracleBINARY_DOUBLE,
sqltypes.Integer: _OracleInteger,
oracle.NUMBER: _OracleNUMBER,
-
sqltypes.Date: _OracleDate,
sqltypes.LargeBinary: _OracleBinary,
sqltypes.Boolean: oracle._OracleBoolean,
@@ -707,7 +716,6 @@ class OracleDialect_cx_oracle(OracleDialect):
sqltypes.UnicodeText: _OracleUnicodeTextCLOB,
sqltypes.CHAR: _OracleChar,
sqltypes.Enum: _OracleEnum,
-
oracle.LONG: _OracleLong,
oracle.RAW: _OracleRaw,
sqltypes.Unicode: _OracleUnicodeStringCHAR,
@@ -721,13 +729,15 @@ class OracleDialect_cx_oracle(OracleDialect):
_cx_oracle_threaded = None
- def __init__(self,
- auto_convert_lobs=True,
- coerce_to_unicode=True,
- coerce_to_decimal=True,
- arraysize=50,
- threaded=None,
- **kwargs):
+ def __init__(
+ self,
+ auto_convert_lobs=True,
+ coerce_to_unicode=True,
+ coerce_to_decimal=True,
+ arraysize=50,
+ threaded=None,
+ **kwargs
+ ):
OracleDialect.__init__(self, **kwargs)
self.arraysize = arraysize
@@ -757,15 +767,23 @@ class OracleDialect_cx_oracle(OracleDialect):
self.cx_oracle_ver = self._parse_cx_oracle_ver(cx_Oracle.version)
if self.cx_oracle_ver < (5, 2) and self.cx_oracle_ver > (0, 0, 0):
raise exc.InvalidRequestError(
- "cx_Oracle version 5.2 and above are supported")
+ "cx_Oracle version 5.2 and above are supported"
+ )
self._has_native_int = hasattr(cx_Oracle, "NATIVE_INT")
self._include_setinputsizes = {
- cx_Oracle.NCLOB, cx_Oracle.CLOB, cx_Oracle.LOB,
- cx_Oracle.NCHAR, cx_Oracle.FIXED_NCHAR,
- cx_Oracle.BLOB, cx_Oracle.FIXED_CHAR, cx_Oracle.TIMESTAMP,
- _OracleInteger, _OracleBINARY_FLOAT, _OracleBINARY_DOUBLE
+ cx_Oracle.NCLOB,
+ cx_Oracle.CLOB,
+ cx_Oracle.LOB,
+ cx_Oracle.NCHAR,
+ cx_Oracle.FIXED_NCHAR,
+ cx_Oracle.BLOB,
+ cx_Oracle.FIXED_CHAR,
+ cx_Oracle.TIMESTAMP,
+ _OracleInteger,
+ _OracleBINARY_FLOAT,
+ _OracleBINARY_DOUBLE,
}
self._paramval = lambda value: value.getvalue()
@@ -786,18 +804,19 @@ class OracleDialect_cx_oracle(OracleDialect):
else:
self._returningval = self._paramval
- self._is_cx_oracle_6 = self.cx_oracle_ver >= (6, )
+ self._is_cx_oracle_6 = self.cx_oracle_ver >= (6,)
def _pop_deprecated_kwargs(self, kwargs):
- auto_setinputsizes = kwargs.pop('auto_setinputsizes', None)
- exclude_setinputsizes = kwargs.pop('exclude_setinputsizes', None)
+ auto_setinputsizes = kwargs.pop("auto_setinputsizes", None)
+ exclude_setinputsizes = kwargs.pop("exclude_setinputsizes", None)
if auto_setinputsizes or exclude_setinputsizes:
util.warn_deprecated(
"auto_setinputsizes and exclude_setinputsizes are deprecated. "
"Modern cx_Oracle only requires that LOB types are part "
"of this behavior, and these parameters no longer have any "
- "effect.")
- allow_twophase = kwargs.pop('allow_twophase', None)
+ "effect."
+ )
+ allow_twophase = kwargs.pop("allow_twophase", None)
if allow_twophase is not None:
util.warn_deprecated(
"allow_twophase is deprecated. The cx_Oracle dialect no "
@@ -805,18 +824,16 @@ class OracleDialect_cx_oracle(OracleDialect):
)
def _parse_cx_oracle_ver(self, version):
- m = re.match(r'(\d+)\.(\d+)(?:\.(\d+))?', version)
+ m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
if m:
- return tuple(
- int(x)
- for x in m.group(1, 2, 3)
- if x is not None)
+ return tuple(int(x) for x in m.group(1, 2, 3) if x is not None)
else:
return (0, 0, 0)
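# e.g. _parse_cx_oracle_ver("6.4.1") -> (6, 4, 1); an unparseable version
# string falls back to (0, 0, 0).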
@classmethod
def dbapi(cls):
import cx_Oracle
+
return cx_Oracle
def initialize(self, connection):
@@ -835,15 +852,18 @@ class OracleDialect_cx_oracle(OracleDialect):
self._decimal_char = connection.scalar(
"select value from nls_session_parameters "
- "where parameter = 'NLS_NUMERIC_CHARACTERS'")[0]
- if self._decimal_char != '.':
+ "where parameter = 'NLS_NUMERIC_CHARACTERS'"
+ )[0]
+ if self._decimal_char != ".":
_detect_decimal = self._detect_decimal
_to_decimal = self._to_decimal
self._detect_decimal = lambda value: _detect_decimal(
- value.replace(self._decimal_char, "."))
+ value.replace(self._decimal_char, ".")
+ )
self._to_decimal = lambda value: _to_decimal(
- value.replace(self._decimal_char, "."))
+ value.replace(self._decimal_char, ".")
+ )
def _detect_decimal(self, value):
if "." in value:
@@ -862,13 +882,16 @@ class OracleDialect_cx_oracle(OracleDialect):
dialect = self
cx_Oracle = dialect.dbapi
- number_handler = _OracleNUMBER(asdecimal=True).\
- _cx_oracle_outputtypehandler(dialect)
- float_handler = _OracleNUMBER(asdecimal=False).\
- _cx_oracle_outputtypehandler(dialect)
+ number_handler = _OracleNUMBER(
+ asdecimal=True
+ )._cx_oracle_outputtypehandler(dialect)
+ float_handler = _OracleNUMBER(
+ asdecimal=False
+ )._cx_oracle_outputtypehandler(dialect)
- def output_type_handler(cursor, name, default_type,
- size, precision, scale):
+ def output_type_handler(
+ cursor, name, default_type, size, precision, scale
+ ):
if default_type == cx_Oracle.NUMBER:
if not dialect.coerce_to_decimal:
return None
@@ -879,7 +902,8 @@ class OracleDialect_cx_oracle(OracleDialect):
cx_Oracle.STRING,
255,
outconverter=dialect._detect_decimal,
- arraysize=cursor.arraysize)
+ arraysize=cursor.arraysize,
+ )
elif precision and scale > 0:
return number_handler(
cursor, name, default_type, size, precision, scale
@@ -890,43 +914,55 @@ class OracleDialect_cx_oracle(OracleDialect):
)
# allow all strings to come back natively as Unicode
- elif dialect.coerce_to_unicode and \
- default_type in (cx_Oracle.STRING, cx_Oracle.FIXED_CHAR):
+ elif dialect.coerce_to_unicode and default_type in (
+ cx_Oracle.STRING,
+ cx_Oracle.FIXED_CHAR,
+ ):
if compat.py2k:
outconverter = processors.to_unicode_processor_factory(
- dialect.encoding, None)
- return cursor.var(
- cx_Oracle.STRING, size, cursor.arraysize,
- outconverter=outconverter
+ dialect.encoding, None
)
- else:
return cursor.var(
- util.text_type, size, cursor.arraysize
+ cx_Oracle.STRING,
+ size,
+ cursor.arraysize,
+ outconverter=outconverter,
)
+ else:
+ return cursor.var(util.text_type, size, cursor.arraysize)
elif dialect.auto_convert_lobs and default_type in (
- cx_Oracle.CLOB, cx_Oracle.NCLOB
+ cx_Oracle.CLOB,
+ cx_Oracle.NCLOB,
):
if compat.py2k:
outconverter = processors.to_unicode_processor_factory(
- dialect.encoding, None)
+ dialect.encoding, None
+ )
return cursor.var(
- default_type, size, cursor.arraysize,
- outconverter=lambda value: outconverter(value.read())
+ default_type,
+ size,
+ cursor.arraysize,
+ outconverter=lambda value: outconverter(value.read()),
)
else:
return cursor.var(
- default_type, size, cursor.arraysize,
- outconverter=lambda value: value.read()
+ default_type,
+ size,
+ cursor.arraysize,
+ outconverter=lambda value: value.read(),
)
elif dialect.auto_convert_lobs and default_type in (
- cx_Oracle.BLOB,
+ cx_Oracle.BLOB,
):
return cursor.var(
- default_type, size, cursor.arraysize,
- outconverter=lambda value: value.read()
+ default_type,
+ size,
+ cursor.arraysize,
+ outconverter=lambda value: value.read(),
)
+
return output_type_handler
def on_connect(self):
@@ -941,16 +977,17 @@ class OracleDialect_cx_oracle(OracleDialect):
def create_connect_args(self, url):
opts = dict(url.query)
- for opt in ('use_ansi', 'auto_convert_lobs'):
+ for opt in ("use_ansi", "auto_convert_lobs"):
if opt in opts:
util.warn_deprecated(
"cx_oracle dialect option %r should only be passed to "
- "create_engine directly, not within the URL string" % opt)
+ "create_engine directly, not within the URL string" % opt
+ )
util.coerce_kw_type(opts, opt, bool)
setattr(self, opt, opts.pop(opt))
database = url.database
- service_name = opts.pop('service_name', None)
+ service_name = opts.pop("service_name", None)
if database or service_name:
# if we have a database, then we have a remote host
port = url.port
@@ -962,11 +999,12 @@ class OracleDialect_cx_oracle(OracleDialect):
if database and service_name:
raise exc.InvalidRequestError(
'"service_name" option shouldn\'t '
- 'be used with a "database" part of the url')
+ 'be used with a "database" part of the url'
+ )
if database:
- makedsn_kwargs = {'sid': database}
+ makedsn_kwargs = {"sid": database}
if service_name:
- makedsn_kwargs = {'service_name': service_name}
+ makedsn_kwargs = {"service_name": service_name}
dsn = self.dbapi.makedsn(url.host, port, **makedsn_kwargs)
else:
@@ -974,11 +1012,11 @@ class OracleDialect_cx_oracle(OracleDialect):
dsn = url.host
if dsn is not None:
- opts['dsn'] = dsn
+ opts["dsn"] = dsn
if url.password is not None:
- opts['password'] = url.password
+ opts["password"] = url.password
if url.username is not None:
- opts['user'] = url.username
+ opts["user"] = url.username
if self._cx_oracle_threaded is not None:
opts.setdefault("threaded", self._cx_oracle_threaded)
@@ -995,28 +1033,24 @@ class OracleDialect_cx_oracle(OracleDialect):
else:
return value
- util.coerce_kw_type(opts, 'mode', convert_cx_oracle_constant)
- util.coerce_kw_type(opts, 'threaded', bool)
- util.coerce_kw_type(opts, 'events', bool)
- util.coerce_kw_type(opts, 'purity', convert_cx_oracle_constant)
+ util.coerce_kw_type(opts, "mode", convert_cx_oracle_constant)
+ util.coerce_kw_type(opts, "threaded", bool)
+ util.coerce_kw_type(opts, "events", bool)
+ util.coerce_kw_type(opts, "purity", convert_cx_oracle_constant)
return ([], opts)
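# Rough sketch, assuming a hypothetical URL
#   oracle+cx_oracle://scott:tiger@dbhost:1521/?service_name=orclpdb1
# this returns ([], opts) with user="scott", password="tiger" and
# dsn=makedsn("dbhost", 1521, service_name="orclpdb1").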
def _get_server_version_info(self, connection):
- return tuple(
- int(x)
- for x in connection.connection.version.split('.')
- )
+ return tuple(int(x) for x in connection.connection.version.split("."))
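# e.g. a server version string of "12.1.0.2.0" becomes (12, 1, 0, 2, 0).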
def is_disconnect(self, e, connection, cursor):
error, = e.args
if isinstance(
- e,
- (self.dbapi.InterfaceError, self.dbapi.DatabaseError)
+ e, (self.dbapi.InterfaceError, self.dbapi.DatabaseError)
) and "not connected" in str(e):
return True
- if hasattr(error, 'code'):
+ if hasattr(error, "code"):
# ORA-00028: your session has been killed
# ORA-03114: not connected to ORACLE
# ORA-03113: end-of-file on communication channel
@@ -1052,22 +1086,25 @@ class OracleDialect_cx_oracle(OracleDialect):
def do_prepare_twophase(self, connection, xid):
result = connection.connection.prepare()
- connection.info['cx_oracle_prepared'] = result
+ connection.info["cx_oracle_prepared"] = result
- def do_rollback_twophase(self, connection, xid, is_prepared=True,
- recover=False):
+ def do_rollback_twophase(
+ self, connection, xid, is_prepared=True, recover=False
+ ):
self.do_rollback(connection.connection)
- def do_commit_twophase(self, connection, xid, is_prepared=True,
- recover=False):
+ def do_commit_twophase(
+ self, connection, xid, is_prepared=True, recover=False
+ ):
if not is_prepared:
self.do_commit(connection.connection)
else:
- oci_prepared = connection.info['cx_oracle_prepared']
+ oci_prepared = connection.info["cx_oracle_prepared"]
if oci_prepared:
self.do_commit(connection.connection)
def do_recover_twophase(self, connection):
- connection.info.pop('cx_oracle_prepared', None)
+ connection.info.pop("cx_oracle_prepared", None)
+
dialect = OracleDialect_cx_oracle
diff --git a/lib/sqlalchemy/dialects/oracle/zxjdbc.py b/lib/sqlalchemy/dialects/oracle/zxjdbc.py
index aa2562573..0a365f8b0 100644
--- a/lib/sqlalchemy/dialects/oracle/zxjdbc.py
+++ b/lib/sqlalchemy/dialects/oracle/zxjdbc.py
@@ -21,9 +21,11 @@ import re
from sqlalchemy import sql, types as sqltypes, util
from sqlalchemy.connectors.zxJDBC import ZxJDBCConnector
-from sqlalchemy.dialects.oracle.base import (OracleCompiler,
- OracleDialect,
- OracleExecutionContext)
+from sqlalchemy.dialects.oracle.base import (
+ OracleCompiler,
+ OracleDialect,
+ OracleExecutionContext,
+)
from sqlalchemy.engine import result as _result
from sqlalchemy.sql import expression
import collections
@@ -32,92 +34,100 @@ SQLException = zxJDBC = None
class _ZxJDBCDate(sqltypes.Date):
-
def result_processor(self, dialect, coltype):
def process(value):
if value is None:
return None
else:
return value.date()
+
return process
class _ZxJDBCNumeric(sqltypes.Numeric):
-
def result_processor(self, dialect, coltype):
# XXX: does the dialect return Decimal or not???
# if it does (in all cases), we could use a None processor as well as
# the to_float generic processor
if self.asdecimal:
+
def process(value):
if isinstance(value, decimal.Decimal):
return value
else:
return decimal.Decimal(str(value))
+
else:
+
def process(value):
if isinstance(value, decimal.Decimal):
return float(value)
else:
return value
+
return process
class OracleCompiler_zxjdbc(OracleCompiler):
-
def returning_clause(self, stmt, returning_cols):
self.returning_cols = list(
- expression._select_iterables(returning_cols))
+ expression._select_iterables(returning_cols)
+ )
# within_columns_clause=False so that labels (foo AS bar) don't render
- columns = [self.process(c, within_columns_clause=False)
- for c in self.returning_cols]
+ columns = [
+ self.process(c, within_columns_clause=False)
+ for c in self.returning_cols
+ ]
- if not hasattr(self, 'returning_parameters'):
+ if not hasattr(self, "returning_parameters"):
self.returning_parameters = []
binds = []
for i, col in enumerate(self.returning_cols):
- dbtype = col.type.dialect_impl(
- self.dialect).get_dbapi_type(self.dialect.dbapi)
+ dbtype = col.type.dialect_impl(self.dialect).get_dbapi_type(
+ self.dialect.dbapi
+ )
self.returning_parameters.append((i + 1, dbtype))
bindparam = sql.bindparam(
- "ret_%d" % i, value=ReturningParam(dbtype))
+ "ret_%d" % i, value=ReturningParam(dbtype)
+ )
self.binds[bindparam.key] = bindparam
binds.append(
- self.bindparam_string(self._truncate_bindparam(bindparam)))
+ self.bindparam_string(self._truncate_bindparam(bindparam))
+ )
- return 'RETURNING ' + ', '.join(columns) + " INTO " + ", ".join(binds)
+ return "RETURNING " + ", ".join(columns) + " INTO " + ", ".join(binds)
class OracleExecutionContext_zxjdbc(OracleExecutionContext):
-
def pre_exec(self):
- if hasattr(self.compiled, 'returning_parameters'):
+ if hasattr(self.compiled, "returning_parameters"):
# prepare a zxJDBC statement so we can grab its underlying
# OraclePreparedStatement's getReturnResultSet later
self.statement = self.cursor.prepare(self.statement)
def get_result_proxy(self):
- if hasattr(self.compiled, 'returning_parameters'):
+ if hasattr(self.compiled, "returning_parameters"):
rrs = None
try:
try:
rrs = self.statement.__statement__.getReturnResultSet()
next(rrs)
except SQLException as sqle:
- msg = '%s [SQLCode: %d]' % (
- sqle.getMessage(), sqle.getErrorCode())
+ msg = "%s [SQLCode: %d]" % (
+ sqle.getMessage(),
+ sqle.getErrorCode(),
+ )
if sqle.getSQLState() is not None:
- msg += ' [SQLState: %s]' % sqle.getSQLState()
+ msg += " [SQLState: %s]" % sqle.getSQLState()
raise zxJDBC.Error(msg)
else:
row = tuple(
- self.cursor.datahandler.getPyObject(
- rrs, index, dbtype)
- for index, dbtype in
- self.compiled.returning_parameters)
+ self.cursor.datahandler.getPyObject(rrs, index, dbtype)
+ for index, dbtype in self.compiled.returning_parameters
+ )
return ReturningResultProxy(self, row)
finally:
if rrs is not None:
@@ -146,7 +156,7 @@ class ReturningResultProxy(_result.FullyBufferedResultProxy):
def _cursor_description(self):
ret = []
for c in self.context.compiled.returning_cols:
- if hasattr(c, 'name'):
+ if hasattr(c, "name"):
ret.append((c.name, c.type))
else:
ret.append((c.anon_label, c.type))
@@ -178,23 +188,24 @@ class ReturningParam(object):
def __repr__(self):
kls = self.__class__
- return '<%s.%s object at 0x%x type=%s>' % (
- kls.__module__, kls.__name__, id(self), self.type)
+ return "<%s.%s object at 0x%x type=%s>" % (
+ kls.__module__,
+ kls.__name__,
+ id(self),
+ self.type,
+ )
class OracleDialect_zxjdbc(ZxJDBCConnector, OracleDialect):
- jdbc_db_name = 'oracle'
- jdbc_driver_name = 'oracle.jdbc.OracleDriver'
+ jdbc_db_name = "oracle"
+ jdbc_driver_name = "oracle.jdbc.OracleDriver"
statement_compiler = OracleCompiler_zxjdbc
execution_ctx_cls = OracleExecutionContext_zxjdbc
colspecs = util.update_copy(
OracleDialect.colspecs,
- {
- sqltypes.Date: _ZxJDBCDate,
- sqltypes.Numeric: _ZxJDBCNumeric
- }
+ {sqltypes.Date: _ZxJDBCDate, sqltypes.Numeric: _ZxJDBCNumeric},
)
def __init__(self, *args, **kwargs):
@@ -212,24 +223,31 @@ class OracleDialect_zxjdbc(ZxJDBCConnector, OracleDialect):
statement.registerReturnParameter(index, object.type)
elif dbtype is None:
OracleDataHandler.setJDBCObject(
- self, statement, index, object)
+ self, statement, index, object
+ )
else:
OracleDataHandler.setJDBCObject(
- self, statement, index, object, dbtype)
+ self, statement, index, object, dbtype
+ )
+
self.DataHandler = OracleReturningDataHandler
def initialize(self, connection):
super(OracleDialect_zxjdbc, self).initialize(connection)
- self.implicit_returning = \
- connection.connection.driverversion >= '10.2'
+ self.implicit_returning = connection.connection.driverversion >= "10.2"
def _create_jdbc_url(self, url):
- return 'jdbc:oracle:thin:@%s:%s:%s' % (
- url.host, url.port or 1521, url.database)
+ return "jdbc:oracle:thin:@%s:%s:%s" % (
+ url.host,
+ url.port or 1521,
+ url.database,
+ )
def _get_server_version_info(self, connection):
version = re.search(
- r'Release ([\d\.]+)', connection.connection.dbversion).group(1)
- return tuple(int(x) for x in version.split('.'))
+ r"Release ([\d\.]+)", connection.connection.dbversion
+ ).group(1)
+ return tuple(int(x) for x in version.split("."))
+
dialect = OracleDialect_zxjdbc
diff --git a/lib/sqlalchemy/dialects/postgresql/__init__.py b/lib/sqlalchemy/dialects/postgresql/__init__.py
index 84f720028..9e65484fa 100644
--- a/lib/sqlalchemy/dialects/postgresql/__init__.py
+++ b/lib/sqlalchemy/dialects/postgresql/__init__.py
@@ -5,33 +5,110 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-from . import base, psycopg2, pg8000, pypostgresql, pygresql, \
- zxjdbc, psycopg2cffi # noqa
+from . import (
+ base,
+ psycopg2,
+ pg8000,
+ pypostgresql,
+ pygresql,
+ zxjdbc,
+ psycopg2cffi,
+) # noqa
-from .base import \
- INTEGER, BIGINT, SMALLINT, VARCHAR, CHAR, TEXT, NUMERIC, FLOAT, REAL, \
- INET, CIDR, UUID, BIT, MACADDR, MONEY, OID, REGCLASS, DOUBLE_PRECISION, \
- TIMESTAMP, TIME, DATE, BYTEA, BOOLEAN, INTERVAL, ENUM, TSVECTOR, \
- DropEnumType, CreateEnumType
+from .base import (
+ INTEGER,
+ BIGINT,
+ SMALLINT,
+ VARCHAR,
+ CHAR,
+ TEXT,
+ NUMERIC,
+ FLOAT,
+ REAL,
+ INET,
+ CIDR,
+ UUID,
+ BIT,
+ MACADDR,
+ MONEY,
+ OID,
+ REGCLASS,
+ DOUBLE_PRECISION,
+ TIMESTAMP,
+ TIME,
+ DATE,
+ BYTEA,
+ BOOLEAN,
+ INTERVAL,
+ ENUM,
+ TSVECTOR,
+ DropEnumType,
+ CreateEnumType,
+)
from .hstore import HSTORE, hstore
from .json import JSON, JSONB
from .array import array, ARRAY, Any, All
from .ext import aggregate_order_by, ExcludeConstraint, array_agg
from .dml import insert, Insert
-from .ranges import INT4RANGE, INT8RANGE, NUMRANGE, DATERANGE, TSRANGE, \
- TSTZRANGE
+from .ranges import (
+ INT4RANGE,
+ INT8RANGE,
+ NUMRANGE,
+ DATERANGE,
+ TSRANGE,
+ TSTZRANGE,
+)
base.dialect = dialect = psycopg2.dialect
__all__ = (
- 'INTEGER', 'BIGINT', 'SMALLINT', 'VARCHAR', 'CHAR', 'TEXT', 'NUMERIC',
- 'FLOAT', 'REAL', 'INET', 'CIDR', 'UUID', 'BIT', 'MACADDR', 'MONEY', 'OID',
- 'REGCLASS', 'DOUBLE_PRECISION', 'TIMESTAMP', 'TIME', 'DATE', 'BYTEA',
- 'BOOLEAN', 'INTERVAL', 'ARRAY', 'ENUM', 'dialect', 'array', 'HSTORE',
- 'hstore', 'INT4RANGE', 'INT8RANGE', 'NUMRANGE', 'DATERANGE',
- 'TSRANGE', 'TSTZRANGE', 'JSON', 'JSONB', 'Any', 'All',
- 'DropEnumType', 'CreateEnumType', 'ExcludeConstraint',
- 'aggregate_order_by', 'array_agg', 'insert', 'Insert'
+ "INTEGER",
+ "BIGINT",
+ "SMALLINT",
+ "VARCHAR",
+ "CHAR",
+ "TEXT",
+ "NUMERIC",
+ "FLOAT",
+ "REAL",
+ "INET",
+ "CIDR",
+ "UUID",
+ "BIT",
+ "MACADDR",
+ "MONEY",
+ "OID",
+ "REGCLASS",
+ "DOUBLE_PRECISION",
+ "TIMESTAMP",
+ "TIME",
+ "DATE",
+ "BYTEA",
+ "BOOLEAN",
+ "INTERVAL",
+ "ARRAY",
+ "ENUM",
+ "dialect",
+ "array",
+ "HSTORE",
+ "hstore",
+ "INT4RANGE",
+ "INT8RANGE",
+ "NUMRANGE",
+ "DATERANGE",
+ "TSRANGE",
+ "TSTZRANGE",
+ "JSON",
+ "JSONB",
+ "Any",
+ "All",
+ "DropEnumType",
+ "CreateEnumType",
+ "ExcludeConstraint",
+ "aggregate_order_by",
+ "array_agg",
+ "insert",
+ "Insert",
)
diff --git a/lib/sqlalchemy/dialects/postgresql/array.py b/lib/sqlalchemy/dialects/postgresql/array.py
index b2674046e..07167f9d0 100644
--- a/lib/sqlalchemy/dialects/postgresql/array.py
+++ b/lib/sqlalchemy/dialects/postgresql/array.py
@@ -78,7 +78,8 @@ class array(expression.Tuple):
:class:`.postgresql.ARRAY`
"""
- __visit_name__ = 'array'
+
+ __visit_name__ = "array"
def __init__(self, clauses, **kw):
super(array, self).__init__(*clauses, **kw)
@@ -90,18 +91,26 @@ class array(expression.Tuple):
# a Slice object from that
assert isinstance(obj, int)
return expression.BindParameter(
- None, obj, _compared_to_operator=operator,
+ None,
+ obj,
+ _compared_to_operator=operator,
type_=type_,
- _compared_to_type=self.type, unique=True)
+ _compared_to_type=self.type,
+ unique=True,
+ )
else:
- return array([
- self._bind_param(operator, o, _assume_scalar=True, type_=type_)
- for o in obj])
+ return array(
+ [
+ self._bind_param(
+ operator, o, _assume_scalar=True, type_=type_
+ )
+ for o in obj
+ ]
+ )
def self_group(self, against=None):
- if (against in (
- operators.any_op, operators.all_op, operators.getitem)):
+ if against in (operators.any_op, operators.all_op, operators.getitem):
return expression.Grouping(self)
else:
return self
@@ -180,7 +189,8 @@ class ARRAY(sqltypes.ARRAY):
elements of the argument array expression.
"""
return self.operate(
- CONTAINED_BY, other, result_type=sqltypes.Boolean)
+ CONTAINED_BY, other, result_type=sqltypes.Boolean
+ )
def overlap(self, other):
"""Boolean expression. Test if array has elements in common with
@@ -190,8 +200,9 @@ class ARRAY(sqltypes.ARRAY):
comparator_factory = Comparator
- def __init__(self, item_type, as_tuple=False, dimensions=None,
- zero_indexes=False):
+ def __init__(
+ self, item_type, as_tuple=False, dimensions=None, zero_indexes=False
+ ):
"""Construct an ARRAY.
E.g.::
@@ -228,8 +239,10 @@ class ARRAY(sqltypes.ARRAY):
"""
if isinstance(item_type, ARRAY):
- raise ValueError("Do not nest ARRAY types; ARRAY(basetype) "
- "handles multi-dimensional arrays of basetype")
+ raise ValueError(
+ "Do not nest ARRAY types; ARRAY(basetype) "
+ "handles multi-dimensional arrays of basetype"
+ )
if isinstance(item_type, type):
item_type = item_type()
self.item_type = item_type
@@ -251,11 +264,17 @@ class ARRAY(sqltypes.ARRAY):
def _proc_array(self, arr, itemproc, dim, collection):
if dim is None:
arr = list(arr)
- if dim == 1 or dim is None and (
+ if (
+ dim == 1
+ or dim is None
+ and (
# this has to be (list, tuple), or at least
# not hasattr('__iter__'), since Py3K strings
# etc. have __iter__
- not arr or not isinstance(arr[0], (list, tuple))):
+ not arr
+ or not isinstance(arr[0], (list, tuple))
+ )
+ ):
if itemproc:
return collection(itemproc(x) for x in arr)
else:
@@ -263,30 +282,33 @@ class ARRAY(sqltypes.ARRAY):
else:
return collection(
self._proc_array(
- x, itemproc,
+ x,
+ itemproc,
dim - 1 if dim is not None else None,
- collection)
+ collection,
+ )
for x in arr
)
def bind_processor(self, dialect):
- item_proc = self.item_type.dialect_impl(dialect).\
- bind_processor(dialect)
+ item_proc = self.item_type.dialect_impl(dialect).bind_processor(
+ dialect
+ )
def process(value):
if value is None:
return value
else:
return self._proc_array(
- value,
- item_proc,
- self.dimensions,
- list)
+ value, item_proc, self.dimensions, list
+ )
+
return process
def result_processor(self, dialect, coltype):
- item_proc = self.item_type.dialect_impl(dialect).\
- result_processor(dialect, coltype)
+ item_proc = self.item_type.dialect_impl(dialect).result_processor(
+ dialect, coltype
+ )
def process(value):
if value is None:
@@ -296,8 +318,11 @@ class ARRAY(sqltypes.ARRAY):
value,
item_proc,
self.dimensions,
- tuple if self.as_tuple else list)
+ tuple if self.as_tuple else list,
+ )
+
return process
+
colspecs[sqltypes.ARRAY] = ARRAY
-ischema_names['_array'] = ARRAY
+ischema_names["_array"] = ARRAY
diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py
index d68ab8ef5..11833da57 100644
--- a/lib/sqlalchemy/dialects/postgresql/base.py
+++ b/lib/sqlalchemy/dialects/postgresql/base.py
@@ -930,57 +930,164 @@ try:
except ImportError:
_python_UUID = None
-from sqlalchemy.types import INTEGER, BIGINT, SMALLINT, VARCHAR, \
- CHAR, TEXT, FLOAT, NUMERIC, \
- DATE, BOOLEAN, REAL
+from sqlalchemy.types import (
+ INTEGER,
+ BIGINT,
+ SMALLINT,
+ VARCHAR,
+ CHAR,
+ TEXT,
+ FLOAT,
+ NUMERIC,
+ DATE,
+ BOOLEAN,
+ REAL,
+)
AUTOCOMMIT_REGEXP = re.compile(
- r'\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER|GRANT|REVOKE|'
- 'IMPORT FOREIGN SCHEMA|REFRESH MATERIALIZED VIEW|TRUNCATE)',
- re.I | re.UNICODE)
+ r"\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER|GRANT|REVOKE|"
+ "IMPORT FOREIGN SCHEMA|REFRESH MATERIALIZED VIEW|TRUNCATE)",
+ re.I | re.UNICODE,
+)
RESERVED_WORDS = set(
- ["all", "analyse", "analyze", "and", "any", "array", "as", "asc",
- "asymmetric", "both", "case", "cast", "check", "collate", "column",
- "constraint", "create", "current_catalog", "current_date",
- "current_role", "current_time", "current_timestamp", "current_user",
- "default", "deferrable", "desc", "distinct", "do", "else", "end",
- "except", "false", "fetch", "for", "foreign", "from", "grant", "group",
- "having", "in", "initially", "intersect", "into", "leading", "limit",
- "localtime", "localtimestamp", "new", "not", "null", "of", "off",
- "offset", "old", "on", "only", "or", "order", "placing", "primary",
- "references", "returning", "select", "session_user", "some", "symmetric",
- "table", "then", "to", "trailing", "true", "union", "unique", "user",
- "using", "variadic", "when", "where", "window", "with", "authorization",
- "between", "binary", "cross", "current_schema", "freeze", "full",
- "ilike", "inner", "is", "isnull", "join", "left", "like", "natural",
- "notnull", "outer", "over", "overlaps", "right", "similar", "verbose"
- ])
+ [
+ "all",
+ "analyse",
+ "analyze",
+ "and",
+ "any",
+ "array",
+ "as",
+ "asc",
+ "asymmetric",
+ "both",
+ "case",
+ "cast",
+ "check",
+ "collate",
+ "column",
+ "constraint",
+ "create",
+ "current_catalog",
+ "current_date",
+ "current_role",
+ "current_time",
+ "current_timestamp",
+ "current_user",
+ "default",
+ "deferrable",
+ "desc",
+ "distinct",
+ "do",
+ "else",
+ "end",
+ "except",
+ "false",
+ "fetch",
+ "for",
+ "foreign",
+ "from",
+ "grant",
+ "group",
+ "having",
+ "in",
+ "initially",
+ "intersect",
+ "into",
+ "leading",
+ "limit",
+ "localtime",
+ "localtimestamp",
+ "new",
+ "not",
+ "null",
+ "of",
+ "off",
+ "offset",
+ "old",
+ "on",
+ "only",
+ "or",
+ "order",
+ "placing",
+ "primary",
+ "references",
+ "returning",
+ "select",
+ "session_user",
+ "some",
+ "symmetric",
+ "table",
+ "then",
+ "to",
+ "trailing",
+ "true",
+ "union",
+ "unique",
+ "user",
+ "using",
+ "variadic",
+ "when",
+ "where",
+ "window",
+ "with",
+ "authorization",
+ "between",
+ "binary",
+ "cross",
+ "current_schema",
+ "freeze",
+ "full",
+ "ilike",
+ "inner",
+ "is",
+ "isnull",
+ "join",
+ "left",
+ "like",
+ "natural",
+ "notnull",
+ "outer",
+ "over",
+ "overlaps",
+ "right",
+ "similar",
+ "verbose",
+ ]
+)
_DECIMAL_TYPES = (1231, 1700)
_FLOAT_TYPES = (700, 701, 1021, 1022)
_INT_TYPES = (20, 21, 23, 26, 1005, 1007, 1016)
+
class BYTEA(sqltypes.LargeBinary):
- __visit_name__ = 'BYTEA'
+ __visit_name__ = "BYTEA"
class DOUBLE_PRECISION(sqltypes.Float):
- __visit_name__ = 'DOUBLE_PRECISION'
+ __visit_name__ = "DOUBLE_PRECISION"
class INET(sqltypes.TypeEngine):
__visit_name__ = "INET"
+
+
PGInet = INET
class CIDR(sqltypes.TypeEngine):
__visit_name__ = "CIDR"
+
+
PGCidr = CIDR
class MACADDR(sqltypes.TypeEngine):
__visit_name__ = "MACADDR"
+
+
PGMacAddr = MACADDR
@@ -991,6 +1098,7 @@ class MONEY(sqltypes.TypeEngine):
.. versionadded:: 1.2
"""
+
__visit_name__ = "MONEY"
@@ -1001,6 +1109,7 @@ class OID(sqltypes.TypeEngine):
.. versionadded:: 0.9.5
"""
+
__visit_name__ = "OID"
@@ -1011,18 +1120,17 @@ class REGCLASS(sqltypes.TypeEngine):
.. versionadded:: 1.2.7
"""
+
__visit_name__ = "REGCLASS"
class TIMESTAMP(sqltypes.TIMESTAMP):
-
def __init__(self, timezone=False, precision=None):
super(TIMESTAMP, self).__init__(timezone=timezone)
self.precision = precision
class TIME(sqltypes.TIME):
-
def __init__(self, timezone=False, precision=None):
super(TIME, self).__init__(timezone=timezone)
self.precision = precision
@@ -1036,7 +1144,8 @@ class INTERVAL(sqltypes.NativeForEmulated, sqltypes._AbstractInterval):
It is known to work on psycopg2 and not pg8000 or zxjdbc.
"""
- __visit_name__ = 'INTERVAL'
+
+ __visit_name__ = "INTERVAL"
native = True
def __init__(self, precision=None, fields=None):
@@ -1065,11 +1174,12 @@ class INTERVAL(sqltypes.NativeForEmulated, sqltypes._AbstractInterval):
def python_type(self):
return dt.timedelta
+
PGInterval = INTERVAL
class BIT(sqltypes.TypeEngine):
- __visit_name__ = 'BIT'
+ __visit_name__ = "BIT"
def __init__(self, length=None, varying=False):
if not varying:
@@ -1080,6 +1190,7 @@ class BIT(sqltypes.TypeEngine):
self.length = length
self.varying = varying
+
PGBit = BIT
@@ -1095,7 +1206,8 @@ class UUID(sqltypes.TypeEngine):
It is known to work on psycopg2 and not pg8000.
"""
- __visit_name__ = 'UUID'
+
+ __visit_name__ = "UUID"
def __init__(self, as_uuid=False):
"""Construct a UUID type.
@@ -1115,24 +1227,29 @@ class UUID(sqltypes.TypeEngine):
def bind_processor(self, dialect):
if self.as_uuid:
+
def process(value):
if value is not None:
value = util.text_type(value)
return value
+
return process
else:
return None
def result_processor(self, dialect, coltype):
if self.as_uuid:
+
def process(value):
if value is not None:
value = _python_UUID(value)
return value
+
return process
else:
return None
+
PGUuid = UUID
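
Note: with as_uuid=True, the bind processor above stringifies incoming uuid.UUID values and the result processor rebuilds uuid.UUID objects from the driver's strings (provided the stdlib uuid module imported as _python_UUID). Minimal usage, with a hypothetical tokens table:

    import uuid
    from sqlalchemy import Column, MetaData, Table
    from sqlalchemy.dialects.postgresql import UUID

    tokens = Table(
        "tokens", MetaData(),
        Column("id", UUID(as_uuid=True), primary_key=True),
    )
    # inserts bind uuid.uuid4() as its text form; SELECTs hand back uuid.UUID
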
@@ -1151,7 +1268,8 @@ class TSVECTOR(sqltypes.TypeEngine):
:ref:`postgresql_match`
"""
- __visit_name__ = 'TSVECTOR'
+
+ __visit_name__ = "TSVECTOR"
class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum):
@@ -1273,12 +1391,12 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum):
"""
kw.setdefault("validate_strings", impl.validate_strings)
- kw.setdefault('name', impl.name)
- kw.setdefault('schema', impl.schema)
- kw.setdefault('inherit_schema', impl.inherit_schema)
- kw.setdefault('metadata', impl.metadata)
- kw.setdefault('_create_events', False)
- kw.setdefault('values_callable', impl.values_callable)
+ kw.setdefault("name", impl.name)
+ kw.setdefault("schema", impl.schema)
+ kw.setdefault("inherit_schema", impl.inherit_schema)
+ kw.setdefault("metadata", impl.metadata)
+ kw.setdefault("_create_events", False)
+ kw.setdefault("values_callable", impl.values_callable)
return cls(**kw)
def create(self, bind=None, checkfirst=True):
@@ -1300,9 +1418,9 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum):
if not bind.dialect.supports_native_enum:
return
- if not checkfirst or \
- not bind.dialect.has_type(
- bind, self.name, schema=self.schema):
+ if not checkfirst or not bind.dialect.has_type(
+ bind, self.name, schema=self.schema
+ ):
bind.execute(CreateEnumType(self))
def drop(self, bind=None, checkfirst=True):
@@ -1323,8 +1441,9 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum):
if not bind.dialect.supports_native_enum:
return
- if not checkfirst or \
- bind.dialect.has_type(bind, self.name, schema=self.schema):
+ if not checkfirst or bind.dialect.has_type(
+ bind, self.name, schema=self.schema
+ ):
bind.execute(DropEnumType(self))
def _check_for_name_in_memos(self, checkfirst, kw):
@@ -1338,12 +1457,12 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum):
"""
if not self.create_type:
return True
- if '_ddl_runner' in kw:
- ddl_runner = kw['_ddl_runner']
- if '_pg_enums' in ddl_runner.memo:
- pg_enums = ddl_runner.memo['_pg_enums']
+ if "_ddl_runner" in kw:
+ ddl_runner = kw["_ddl_runner"]
+ if "_pg_enums" in ddl_runner.memo:
+ pg_enums = ddl_runner.memo["_pg_enums"]
else:
- pg_enums = ddl_runner.memo['_pg_enums'] = set()
+ pg_enums = ddl_runner.memo["_pg_enums"] = set()
present = (self.schema, self.name) in pg_enums
pg_enums.add((self.schema, self.name))
return present
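
Note: _check_for_name_in_memos() is what lets one ENUM object be shared by several tables within a single DDL run without duplicate CREATE TYPE statements: the first table to reference it records (schema, name) in ddl_runner.memo["_pg_enums"], and later tables find it there and skip. Sketch of the sharing pattern (table names hypothetical):

    from sqlalchemy import Column, Integer, MetaData, Table
    from sqlalchemy.dialects.postgresql import ENUM

    meta = MetaData()
    status = ENUM("draft", "published", name="status_enum", metadata=meta)
    Table("posts", meta, Column("id", Integer, primary_key=True),
          Column("status", status))
    Table("pages", meta, Column("id", Integer, primary_key=True),
          Column("status", status))
    # meta.create_all(engine) emits CREATE TYPE status_enum exactly once
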
@@ -1351,16 +1470,22 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum):
return False
def _on_table_create(self, target, bind, checkfirst=False, **kw):
- if checkfirst or (
- not self.metadata and
- not kw.get('_is_metadata_operation', False)) and \
- not self._check_for_name_in_memos(checkfirst, kw):
+ if (
+ checkfirst
+ or (
+ not self.metadata
+ and not kw.get("_is_metadata_operation", False)
+ )
+ and not self._check_for_name_in_memos(checkfirst, kw)
+ ):
self.create(bind=bind, checkfirst=checkfirst)
def _on_table_drop(self, target, bind, checkfirst=False, **kw):
- if not self.metadata and \
- not kw.get('_is_metadata_operation', False) and \
- not self._check_for_name_in_memos(checkfirst, kw):
+ if (
+ not self.metadata
+ and not kw.get("_is_metadata_operation", False)
+ and not self._check_for_name_in_memos(checkfirst, kw)
+ ):
self.drop(bind=bind, checkfirst=checkfirst)
def _on_metadata_create(self, target, bind, checkfirst=False, **kw):
@@ -1371,49 +1496,46 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum):
if not self._check_for_name_in_memos(checkfirst, kw):
self.drop(bind=bind, checkfirst=checkfirst)
-colspecs = {
- sqltypes.Interval: INTERVAL,
- sqltypes.Enum: ENUM,
-}
+
+colspecs = {sqltypes.Interval: INTERVAL, sqltypes.Enum: ENUM}
ischema_names = {
- 'integer': INTEGER,
- 'bigint': BIGINT,
- 'smallint': SMALLINT,
- 'character varying': VARCHAR,
- 'character': CHAR,
+ "integer": INTEGER,
+ "bigint": BIGINT,
+ "smallint": SMALLINT,
+ "character varying": VARCHAR,
+ "character": CHAR,
'"char"': sqltypes.String,
- 'name': sqltypes.String,
- 'text': TEXT,
- 'numeric': NUMERIC,
- 'float': FLOAT,
- 'real': REAL,
- 'inet': INET,
- 'cidr': CIDR,
- 'uuid': UUID,
- 'bit': BIT,
- 'bit varying': BIT,
- 'macaddr': MACADDR,
- 'money': MONEY,
- 'oid': OID,
- 'regclass': REGCLASS,
- 'double precision': DOUBLE_PRECISION,
- 'timestamp': TIMESTAMP,
- 'timestamp with time zone': TIMESTAMP,
- 'timestamp without time zone': TIMESTAMP,
- 'time with time zone': TIME,
- 'time without time zone': TIME,
- 'date': DATE,
- 'time': TIME,
- 'bytea': BYTEA,
- 'boolean': BOOLEAN,
- 'interval': INTERVAL,
- 'tsvector': TSVECTOR
+ "name": sqltypes.String,
+ "text": TEXT,
+ "numeric": NUMERIC,
+ "float": FLOAT,
+ "real": REAL,
+ "inet": INET,
+ "cidr": CIDR,
+ "uuid": UUID,
+ "bit": BIT,
+ "bit varying": BIT,
+ "macaddr": MACADDR,
+ "money": MONEY,
+ "oid": OID,
+ "regclass": REGCLASS,
+ "double precision": DOUBLE_PRECISION,
+ "timestamp": TIMESTAMP,
+ "timestamp with time zone": TIMESTAMP,
+ "timestamp without time zone": TIMESTAMP,
+ "time with time zone": TIME,
+ "time without time zone": TIME,
+ "date": DATE,
+ "time": TIME,
+ "bytea": BYTEA,
+ "boolean": BOOLEAN,
+ "interval": INTERVAL,
+ "tsvector": TSVECTOR,
}
class PGCompiler(compiler.SQLCompiler):
-
def visit_array(self, element, **kw):
return "ARRAY[%s]" % self.visit_clauselist(element, **kw)
@@ -1424,77 +1546,75 @@ class PGCompiler(compiler.SQLCompiler):
)
def visit_json_getitem_op_binary(self, binary, operator, **kw):
- kw['eager_grouping'] = True
- return self._generate_generic_binary(
- binary, " -> ", **kw
- )
+ kw["eager_grouping"] = True
+ return self._generate_generic_binary(binary, " -> ", **kw)
def visit_json_path_getitem_op_binary(self, binary, operator, **kw):
- kw['eager_grouping'] = True
- return self._generate_generic_binary(
- binary, " #> ", **kw
- )
+ kw["eager_grouping"] = True
+ return self._generate_generic_binary(binary, " #> ", **kw)
def visit_getitem_binary(self, binary, operator, **kw):
return "%s[%s]" % (
self.process(binary.left, **kw),
- self.process(binary.right, **kw)
+ self.process(binary.right, **kw),
)
def visit_aggregate_order_by(self, element, **kw):
return "%s ORDER BY %s" % (
self.process(element.target, **kw),
- self.process(element.order_by, **kw)
+ self.process(element.order_by, **kw),
)
def visit_match_op_binary(self, binary, operator, **kw):
if "postgresql_regconfig" in binary.modifiers:
regconfig = self.render_literal_value(
- binary.modifiers['postgresql_regconfig'],
- sqltypes.STRINGTYPE)
+ binary.modifiers["postgresql_regconfig"], sqltypes.STRINGTYPE
+ )
if regconfig:
return "%s @@ to_tsquery(%s, %s)" % (
self.process(binary.left, **kw),
regconfig,
- self.process(binary.right, **kw)
+ self.process(binary.right, **kw),
)
return "%s @@ to_tsquery(%s)" % (
self.process(binary.left, **kw),
- self.process(binary.right, **kw)
+ self.process(binary.right, **kw),
)
def visit_ilike_op_binary(self, binary, operator, **kw):
escape = binary.modifiers.get("escape", None)
- return '%s ILIKE %s' % \
- (self.process(binary.left, **kw),
- self.process(binary.right, **kw)) \
- + (
- ' ESCAPE ' +
- self.render_literal_value(escape, sqltypes.STRINGTYPE)
- if escape else ''
- )
+ return "%s ILIKE %s" % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw),
+ ) + (
+ " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE)
+ if escape
+ else ""
+ )
def visit_notilike_op_binary(self, binary, operator, **kw):
escape = binary.modifiers.get("escape", None)
- return '%s NOT ILIKE %s' % \
- (self.process(binary.left, **kw),
- self.process(binary.right, **kw)) \
- + (
- ' ESCAPE ' +
- self.render_literal_value(escape, sqltypes.STRINGTYPE)
- if escape else ''
- )
+ return "%s NOT ILIKE %s" % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw),
+ ) + (
+ " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE)
+ if escape
+ else ""
+ )
def visit_empty_set_expr(self, element_types):
# cast the empty set to the type we are comparing against. if
# we are comparing against the null type, pick an arbitrary
# datatype for the empty set
- return 'SELECT %s WHERE 1!=1' % (
+ return "SELECT %s WHERE 1!=1" % (
", ".join(
- "CAST(NULL AS %s)" % self.dialect.type_compiler.process(
- INTEGER() if type_._isnull else type_,
- ) for type_ in element_types or [INTEGER()]
+ "CAST(NULL AS %s)"
+ % self.dialect.type_compiler.process(
+ INTEGER() if type_._isnull else type_
+ )
+ for type_ in element_types or [INTEGER()]
),
)
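
Note: visit_empty_set_expr() is what an empty IN expands to when the statement executes: a subquery that can never return a row, with its NULLs cast so the comparison still type-checks. Sketch (table hypothetical):

    from sqlalchemy import Column, Integer, MetaData, Table, select

    t = Table("t", MetaData(), Column("x", Integer))
    stmt = select([t]).where(t.c.x.in_([]))
    # the empty list expands at execution time to:
    #   x IN (SELECT CAST(NULL AS INTEGER) WHERE 1!=1)
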
@@ -1502,7 +1622,7 @@ class PGCompiler(compiler.SQLCompiler):
value = super(PGCompiler, self).render_literal_value(value, type_)
if self.dialect._backslash_escapes:
- value = value.replace('\\', '\\\\')
+ value = value.replace("\\", "\\\\")
return value
def visit_sequence(self, seq, **kw):
@@ -1519,7 +1639,7 @@ class PGCompiler(compiler.SQLCompiler):
return text
def format_from_hint_text(self, sqltext, table, hint, iscrud):
- if hint.upper() != 'ONLY':
+ if hint.upper() != "ONLY":
raise exc.CompileError("Unrecognized hint: %r" % hint)
return "ONLY " + sqltext
@@ -1528,12 +1648,19 @@ class PGCompiler(compiler.SQLCompiler):
if select._distinct is True:
return "DISTINCT "
elif isinstance(select._distinct, (list, tuple)):
- return "DISTINCT ON (" + ', '.join(
- [self.process(col, **kw) for col in select._distinct]
- ) + ") "
+ return (
+ "DISTINCT ON ("
+ + ", ".join(
+ [self.process(col, **kw) for col in select._distinct]
+ )
+ + ") "
+ )
else:
- return "DISTINCT ON (" + \
- self.process(select._distinct, **kw) + ") "
+ return (
+ "DISTINCT ON ("
+ + self.process(select._distinct, **kw)
+ + ") "
+ )
else:
return ""
@@ -1551,8 +1678,9 @@ class PGCompiler(compiler.SQLCompiler):
if select._for_update_arg.of:
tables = util.OrderedSet(
- c.table if isinstance(c, expression.ColumnClause)
- else c for c in select._for_update_arg.of)
+ c.table if isinstance(c, expression.ColumnClause) else c
+ for c in select._for_update_arg.of
+ )
tmp += " OF " + ", ".join(
self.process(table, ashint=True, use_schema=False, **kw)
for table in tables
@@ -1572,7 +1700,7 @@ class PGCompiler(compiler.SQLCompiler):
for c in expression._select_iterables(returning_cols)
]
- return 'RETURNING ' + ', '.join(columns)
+ return "RETURNING " + ", ".join(columns)
def visit_substring_func(self, func, **kw):
s = self.process(func.clauses.clauses[0], **kw)
@@ -1586,24 +1714,24 @@ class PGCompiler(compiler.SQLCompiler):
def _on_conflict_target(self, clause, **kw):
if clause.constraint_target is not None:
- target_text = 'ON CONSTRAINT %s' % clause.constraint_target
+ target_text = "ON CONSTRAINT %s" % clause.constraint_target
elif clause.inferred_target_elements is not None:
- target_text = '(%s)' % ', '.join(
- (self.preparer.quote(c)
- if isinstance(c, util.string_types)
- else
- self.process(c, include_table=False, use_schema=False))
+ target_text = "(%s)" % ", ".join(
+ (
+ self.preparer.quote(c)
+ if isinstance(c, util.string_types)
+ else self.process(c, include_table=False, use_schema=False)
+ )
for c in clause.inferred_target_elements
)
if clause.inferred_target_whereclause is not None:
- target_text += ' WHERE %s' % \
- self.process(
- clause.inferred_target_whereclause,
- include_table=False,
- use_schema=False
- )
+ target_text += " WHERE %s" % self.process(
+ clause.inferred_target_whereclause,
+ include_table=False,
+ use_schema=False,
+ )
else:
- target_text = ''
+ target_text = ""
return target_text
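
Note: the three branches of _on_conflict_target() map to on_conflict_do_update(constraint=...), to index_elements= (optionally plus index_where=), and to the empty target when on_conflict_do_nothing() is called with no arguments. Sketch of the index_elements= case, with a hypothetical users table:

    from sqlalchemy import Column, Integer, MetaData, String, Table
    from sqlalchemy.dialects.postgresql import insert

    users = Table(
        "users", MetaData(),
        Column("id", Integer, primary_key=True),
        Column("name", String),
    )
    stmt = insert(users).values(id=1, name="spongebob")
    stmt = stmt.on_conflict_do_update(
        index_elements=["id"], set_={"name": stmt.excluded.name}
    )
    # ... ON CONFLICT (id) DO UPDATE SET name = excluded.name
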
@@ -1627,36 +1755,35 @@ class PGCompiler(compiler.SQLCompiler):
set_parameters = dict(clause.update_values_to_set)
# create a list of column assignment clauses as tuples
- insert_statement = self.stack[-1]['selectable']
+ insert_statement = self.stack[-1]["selectable"]
cols = insert_statement.table.c
for c in cols:
col_key = c.key
if col_key in set_parameters:
value = set_parameters.pop(col_key)
if elements._is_literal(value):
- value = elements.BindParameter(
- None, value, type_=c.type
- )
+ value = elements.BindParameter(None, value, type_=c.type)
else:
- if isinstance(value, elements.BindParameter) and \
- value.type._isnull:
+ if (
+ isinstance(value, elements.BindParameter)
+ and value.type._isnull
+ ):
value = value._clone()
value.type = c.type
value_text = self.process(value.self_group(), use_schema=False)
- key_text = (
- self.preparer.quote(col_key)
- )
- action_set_ops.append('%s = %s' % (key_text, value_text))
+ key_text = self.preparer.quote(col_key)
+ action_set_ops.append("%s = %s" % (key_text, value_text))
# check for names that don't match columns
if set_parameters:
util.warn(
"Additional column names not matching "
- "any column keys in table '%s': %s" % (
+ "any column keys in table '%s': %s"
+ % (
self.statement.table.name,
- (", ".join("'%s'" % c for c in set_parameters))
+ (", ".join("'%s'" % c for c in set_parameters)),
)
)
for k, v in set_parameters.items():
@@ -1666,42 +1793,37 @@ class PGCompiler(compiler.SQLCompiler):
else self.process(k, use_schema=False)
)
value_text = self.process(
- elements._literal_as_binds(v),
- use_schema=False
+ elements._literal_as_binds(v), use_schema=False
)
- action_set_ops.append('%s = %s' % (key_text, value_text))
+ action_set_ops.append("%s = %s" % (key_text, value_text))
- action_text = ', '.join(action_set_ops)
+ action_text = ", ".join(action_set_ops)
if clause.update_whereclause is not None:
- action_text += ' WHERE %s' % \
- self.process(
- clause.update_whereclause,
- include_table=True,
- use_schema=False
- )
+ action_text += " WHERE %s" % self.process(
+ clause.update_whereclause, include_table=True, use_schema=False
+ )
- return 'ON CONFLICT %s DO UPDATE SET %s' % (target_text, action_text)
+ return "ON CONFLICT %s DO UPDATE SET %s" % (target_text, action_text)
- def update_from_clause(self, update_stmt,
- from_table, extra_froms,
- from_hints,
- **kw):
- return "FROM " + ', '.join(
- t._compiler_dispatch(self, asfrom=True,
- fromhints=from_hints, **kw)
- for t in extra_froms)
+ def update_from_clause(
+ self, update_stmt, from_table, extra_froms, from_hints, **kw
+ ):
+ return "FROM " + ", ".join(
+ t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw)
+ for t in extra_froms
+ )
- def delete_extra_from_clause(self, delete_stmt, from_table,
- extra_froms, from_hints, **kw):
+ def delete_extra_from_clause(
+ self, delete_stmt, from_table, extra_froms, from_hints, **kw
+ ):
"""Render the DELETE .. USING clause specific to PostgreSQL."""
- return "USING " + ', '.join(
- t._compiler_dispatch(self, asfrom=True,
- fromhints=from_hints, **kw)
- for t in extra_froms)
+ return "USING " + ", ".join(
+ t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw)
+ for t in extra_froms
+ )
class PGDDLCompiler(compiler.DDLCompiler):
-
def get_column_specification(self, column, **kwargs):
colspec = self.preparer.format_column(column)
@@ -1709,17 +1831,21 @@ class PGDDLCompiler(compiler.DDLCompiler):
if isinstance(impl_type, sqltypes.TypeDecorator):
impl_type = impl_type.impl
- if column.primary_key and \
- column is column.table._autoincrement_column and \
- (
- self.dialect.supports_smallserial or
- not isinstance(impl_type, sqltypes.SmallInteger)
- ) and (
- column.default is None or
- (
- isinstance(column.default, schema.Sequence) and
- column.default.optional
- )):
+ if (
+ column.primary_key
+ and column is column.table._autoincrement_column
+ and (
+ self.dialect.supports_smallserial
+ or not isinstance(impl_type, sqltypes.SmallInteger)
+ )
+ and (
+ column.default is None
+ or (
+ isinstance(column.default, schema.Sequence)
+ and column.default.optional
+ )
+ )
+ ):
if isinstance(impl_type, sqltypes.BigInteger):
colspec += " BIGSERIAL"
elif isinstance(impl_type, sqltypes.SmallInteger):
@@ -1728,7 +1854,8 @@ class PGDDLCompiler(compiler.DDLCompiler):
colspec += " SERIAL"
else:
colspec += " " + self.dialect.type_compiler.process(
- column.type, type_expression=column)
+ column.type, type_expression=column
+ )
default = self.get_column_default_string(column)
if default is not None:
colspec += " DEFAULT " + default
@@ -1744,15 +1871,14 @@ class PGDDLCompiler(compiler.DDLCompiler):
self.preparer.format_type(type_),
", ".join(
self.sql_compiler.process(sql.literal(e), literal_binds=True)
- for e in type_.enums)
+ for e in type_.enums
+ ),
)
def visit_drop_enum_type(self, drop):
type_ = drop.element
- return "DROP TYPE %s" % (
- self.preparer.format_type(type_)
- )
+ return "DROP TYPE %s" % (self.preparer.format_type(type_))
def visit_create_index(self, create):
preparer = self.preparer
@@ -1764,46 +1890,53 @@ class PGDDLCompiler(compiler.DDLCompiler):
text += "INDEX "
if self.dialect._supports_create_index_concurrently:
- concurrently = index.dialect_options['postgresql']['concurrently']
+ concurrently = index.dialect_options["postgresql"]["concurrently"]
if concurrently:
text += "CONCURRENTLY "
text += "%s ON %s " % (
- self._prepared_index_name(index,
- include_schema=False),
- preparer.format_table(index.table)
+ self._prepared_index_name(index, include_schema=False),
+ preparer.format_table(index.table),
)
- using = index.dialect_options['postgresql']['using']
+ using = index.dialect_options["postgresql"]["using"]
if using:
text += "USING %s " % preparer.quote(using)
ops = index.dialect_options["postgresql"]["ops"]
- text += "(%s)" \
- % (
- ', '.join([
- self.sql_compiler.process(
- expr.self_group()
- if not isinstance(expr, expression.ColumnClause)
- else expr,
- include_table=False, literal_binds=True) +
- (
- (' ' + ops[expr.key])
- if hasattr(expr, 'key')
- and expr.key in ops else ''
- )
- for expr in index.expressions
- ])
- )
+ text += "(%s)" % (
+ ", ".join(
+ [
+ self.sql_compiler.process(
+ expr.self_group()
+ if not isinstance(expr, expression.ColumnClause)
+ else expr,
+ include_table=False,
+ literal_binds=True,
+ )
+ + (
+ (" " + ops[expr.key])
+ if hasattr(expr, "key") and expr.key in ops
+ else ""
+ )
+ for expr in index.expressions
+ ]
+ )
+ )
- withclause = index.dialect_options['postgresql']['with']
+ withclause = index.dialect_options["postgresql"]["with"]
if withclause:
- text += " WITH (%s)" % (', '.join(
- ['%s = %s' % storage_parameter
- for storage_parameter in withclause.items()]))
+ text += " WITH (%s)" % (
+ ", ".join(
+ [
+ "%s = %s" % storage_parameter
+ for storage_parameter in withclause.items()
+ ]
+ )
+ )
- tablespace_name = index.dialect_options['postgresql']['tablespace']
+ tablespace_name = index.dialect_options["postgresql"]["tablespace"]
if tablespace_name:
text += " TABLESPACE %s" % preparer.quote(tablespace_name)
@@ -1812,8 +1945,8 @@ class PGDDLCompiler(compiler.DDLCompiler):
if whereclause is not None:
where_compiled = self.sql_compiler.process(
- whereclause, include_table=False,
- literal_binds=True)
+ whereclause, include_table=False, literal_binds=True
+ )
text += " WHERE " + where_compiled
return text
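
Note: visit_create_index() consumes the per-index dialect options registered under construct_arguments below: postgresql_concurrently, postgresql_using, postgresql_ops, postgresql_with, postgresql_tablespace, and the partial-index postgresql_where. For instance (assumes the pg_trgm extension for the operator class; names hypothetical):

    from sqlalchemy import Column, Index, MetaData, Table, Text

    docs = Table("docs", MetaData(), Column("body", Text))
    Index(
        "ix_docs_body",
        docs.c.body,
        postgresql_using="gin",
        postgresql_ops={"body": "gin_trgm_ops"},
        postgresql_with={"fastupdate": "off"},
    )
    # CREATE INDEX ix_docs_body ON docs USING gin (body gin_trgm_ops)
    # WITH (fastupdate = off)
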
@@ -1823,7 +1956,7 @@ class PGDDLCompiler(compiler.DDLCompiler):
text = "\nDROP INDEX "
if self.dialect._supports_drop_index_concurrently:
- concurrently = index.dialect_options['postgresql']['concurrently']
+ concurrently = index.dialect_options["postgresql"]["concurrently"]
if concurrently:
text += "CONCURRENTLY "
@@ -1833,55 +1966,59 @@ class PGDDLCompiler(compiler.DDLCompiler):
def visit_exclude_constraint(self, constraint, **kw):
text = ""
if constraint.name is not None:
- text += "CONSTRAINT %s " % \
- self.preparer.format_constraint(constraint)
+ text += "CONSTRAINT %s " % self.preparer.format_constraint(
+ constraint
+ )
elements = []
for expr, name, op in constraint._render_exprs:
- kw['include_table'] = False
+ kw["include_table"] = False
elements.append(
"%s WITH %s" % (self.sql_compiler.process(expr, **kw), op)
)
- text += "EXCLUDE USING %s (%s)" % (constraint.using,
- ', '.join(elements))
+ text += "EXCLUDE USING %s (%s)" % (
+ constraint.using,
+ ", ".join(elements),
+ )
if constraint.where is not None:
- text += ' WHERE (%s)' % self.sql_compiler.process(
- constraint.where,
- literal_binds=True)
+ text += " WHERE (%s)" % self.sql_compiler.process(
+ constraint.where, literal_binds=True
+ )
text += self.define_constraint_deferrability(constraint)
return text
def post_create_table(self, table):
table_opts = []
- pg_opts = table.dialect_options['postgresql']
+ pg_opts = table.dialect_options["postgresql"]
- inherits = pg_opts.get('inherits')
+ inherits = pg_opts.get("inherits")
if inherits is not None:
if not isinstance(inherits, (list, tuple)):
- inherits = (inherits, )
+ inherits = (inherits,)
table_opts.append(
- '\n INHERITS ( ' +
- ', '.join(self.preparer.quote(name) for name in inherits) +
- ' )')
+ "\n INHERITS ( "
+ + ", ".join(self.preparer.quote(name) for name in inherits)
+ + " )"
+ )
- if pg_opts['partition_by']:
- table_opts.append('\n PARTITION BY %s' % pg_opts['partition_by'])
+ if pg_opts["partition_by"]:
+ table_opts.append("\n PARTITION BY %s" % pg_opts["partition_by"])
- if pg_opts['with_oids'] is True:
- table_opts.append('\n WITH OIDS')
- elif pg_opts['with_oids'] is False:
- table_opts.append('\n WITHOUT OIDS')
+ if pg_opts["with_oids"] is True:
+ table_opts.append("\n WITH OIDS")
+ elif pg_opts["with_oids"] is False:
+ table_opts.append("\n WITHOUT OIDS")
- if pg_opts['on_commit']:
- on_commit_options = pg_opts['on_commit'].replace("_", " ").upper()
- table_opts.append('\n ON COMMIT %s' % on_commit_options)
+ if pg_opts["on_commit"]:
+ on_commit_options = pg_opts["on_commit"].replace("_", " ").upper()
+ table_opts.append("\n ON COMMIT %s" % on_commit_options)
- if pg_opts['tablespace']:
- tablespace_name = pg_opts['tablespace']
+ if pg_opts["tablespace"]:
+ tablespace_name = pg_opts["tablespace"]
table_opts.append(
- '\n TABLESPACE %s' % self.preparer.quote(tablespace_name)
+ "\n TABLESPACE %s" % self.preparer.quote(tablespace_name)
)
- return ''.join(table_opts)
+ return "".join(table_opts)
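
Note: post_create_table() appends the table-level options in a fixed order: INHERITS, PARTITION BY, WITH/WITHOUT OIDS, ON COMMIT, TABLESPACE. E.g. (hypothetical table):

    from sqlalchemy import Column, Date, MetaData, Table

    measurements = Table(
        "measurements", MetaData(),
        Column("logdate", Date, nullable=False),
        postgresql_partition_by="RANGE (logdate)",
        postgresql_tablespace="fast_ssd",
    )
    # CREATE TABLE measurements (logdate DATE NOT NULL)
    #   PARTITION BY RANGE (logdate)
    #   TABLESPACE fast_ssd
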
class PGTypeCompiler(compiler.GenericTypeCompiler):
@@ -1910,7 +2047,7 @@ class PGTypeCompiler(compiler.GenericTypeCompiler):
if not type_.precision:
return "FLOAT"
else:
- return "FLOAT(%(precision)s)" % {'precision': type_.precision}
+ return "FLOAT(%(precision)s)" % {"precision": type_.precision}
def visit_DOUBLE_PRECISION(self, type_, **kw):
return "DOUBLE PRECISION"
@@ -1960,15 +2097,17 @@ class PGTypeCompiler(compiler.GenericTypeCompiler):
def visit_TIMESTAMP(self, type_, **kw):
return "TIMESTAMP%s %s" % (
"(%d)" % type_.precision
- if getattr(type_, 'precision', None) is not None else "",
- (type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE"
+ if getattr(type_, "precision", None) is not None
+ else "",
+ (type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE",
)
def visit_TIME(self, type_, **kw):
return "TIME%s %s" % (
"(%d)" % type_.precision
- if getattr(type_, 'precision', None) is not None else "",
- (type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE"
+ if getattr(type_, "precision", None) is not None
+ else "",
+ (type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE",
)
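
Note: both visitors splice the PostgreSQL-specific precision in front of the standard timezone suffix, so:

    from sqlalchemy.dialects import postgresql
    from sqlalchemy.dialects.postgresql import TIMESTAMP

    tc = postgresql.dialect().type_compiler
    print(tc.process(TIMESTAMP(timezone=True, precision=6)))
    # TIMESTAMP(6) WITH TIME ZONE
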
def visit_INTERVAL(self, type_, **kw):
@@ -2002,13 +2141,16 @@ class PGTypeCompiler(compiler.GenericTypeCompiler):
# TODO: pass **kw?
inner = self.process(type_.item_type)
return re.sub(
- r'((?: COLLATE.*)?)$',
- (r'%s\1' % (
- "[]" *
- (type_.dimensions if type_.dimensions is not None else 1)
- )),
+ r"((?: COLLATE.*)?)$",
+ (
+ r"%s\1"
+ % (
+ "[]"
+ * (type_.dimensions if type_.dimensions is not None else 1)
+ )
+ ),
inner,
- count=1
+ count=1,
)
@@ -2018,8 +2160,9 @@ class PGIdentifierPreparer(compiler.IdentifierPreparer):
def _unquote_identifier(self, value):
if value[0] == self.initial_quote:
- value = value[1:-1].\
- replace(self.escape_to_quote, self.escape_quote)
+ value = value[1:-1].replace(
+ self.escape_to_quote, self.escape_quote
+ )
return value
def format_type(self, type_, use_schema=True):
@@ -2029,22 +2172,25 @@ class PGIdentifierPreparer(compiler.IdentifierPreparer):
name = self.quote(type_.name)
effective_schema = self.schema_for_object(type_)
- if not self.omit_schema and use_schema and \
- effective_schema is not None:
+ if (
+ not self.omit_schema
+ and use_schema
+ and effective_schema is not None
+ ):
name = self.quote_schema(effective_schema) + "." + name
return name
class PGInspector(reflection.Inspector):
-
def __init__(self, conn):
reflection.Inspector.__init__(self, conn)
def get_table_oid(self, table_name, schema=None):
"""Return the OID for the given table name."""
- return self.dialect.get_table_oid(self.bind, table_name, schema,
- info_cache=self.info_cache)
+ return self.dialect.get_table_oid(
+ self.bind, table_name, schema, info_cache=self.info_cache
+ )
def get_enums(self, schema=None):
"""Return a list of ENUM objects.
@@ -2080,7 +2226,7 @@ class PGInspector(reflection.Inspector):
schema = schema or self.default_schema_name
return self.dialect._get_foreign_table_names(self.bind, schema)
- def get_view_names(self, schema=None, include=('plain', 'materialized')):
+ def get_view_names(self, schema=None, include=("plain", "materialized")):
"""Return all view names in `schema`.
:param schema: Optional, retrieve names from a non-default schema.
@@ -2094,9 +2240,9 @@ class PGInspector(reflection.Inspector):
"""
- return self.dialect.get_view_names(self.bind, schema,
- info_cache=self.info_cache,
- include=include)
+ return self.dialect.get_view_names(
+ self.bind, schema, info_cache=self.info_cache, include=include
+ )
class CreateEnumType(schema._CreateDropBase):
@@ -2108,25 +2254,27 @@ class DropEnumType(schema._CreateDropBase):
class PGExecutionContext(default.DefaultExecutionContext):
-
def fire_sequence(self, seq, type_):
- return self._execute_scalar((
- "select nextval('%s')" %
- self.dialect.identifier_preparer.format_sequence(seq)), type_)
+ return self._execute_scalar(
+ (
+ "select nextval('%s')"
+ % self.dialect.identifier_preparer.format_sequence(seq)
+ ),
+ type_,
+ )
def get_insert_default(self, column):
- if column.primary_key and \
- column is column.table._autoincrement_column:
+ if column.primary_key and column is column.table._autoincrement_column:
if column.server_default and column.server_default.has_argument:
# pre-execute passive defaults on primary key columns
- return self._execute_scalar("select %s" %
- column.server_default.arg,
- column.type)
+ return self._execute_scalar(
+ "select %s" % column.server_default.arg, column.type
+ )
- elif (column.default is None or
- (column.default.is_sequence and
- column.default.optional)):
+ elif column.default is None or (
+ column.default.is_sequence and column.default.optional
+ ):
# execute the sequence associated with a SERIAL primary
# key column. for non-primary-key SERIAL, the ID just
@@ -2137,23 +2285,25 @@ class PGExecutionContext(default.DefaultExecutionContext):
except AttributeError:
tab = column.table.name
col = column.name
- tab = tab[0:29 + max(0, (29 - len(col)))]
- col = col[0:29 + max(0, (29 - len(tab)))]
+ tab = tab[0 : 29 + max(0, (29 - len(col)))]
+ col = col[0 : 29 + max(0, (29 - len(tab)))]
name = "%s_%s_seq" % (tab, col)
column._postgresql_seq_name = seq_name = name
if column.table is not None:
effective_schema = self.connection.schema_for_object(
- column.table)
+ column.table
+ )
else:
effective_schema = None
if effective_schema is not None:
- exc = "select nextval('\"%s\".\"%s\"')" % \
- (effective_schema, seq_name)
+ exc = 'select nextval(\'"%s"."%s"\')' % (
+ effective_schema,
+ seq_name,
+ )
else:
- exc = "select nextval('\"%s\"')" % \
- (seq_name, )
+ exc = "select nextval('\"%s\"')" % (seq_name,)
return self._execute_scalar(exc, column.type)
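
Note: the slicing above mirrors how PostgreSQL itself truncates the implicit SERIAL sequence name to the 63-byte identifier limit: each side keeps 29 characters plus whatever slack the other side leaves unused. Restated as a standalone helper (helper name hypothetical):

    def serial_seq_name(tab, col):
        # each side gets 29 chars plus the other side's unused slack
        tab = tab[0 : 29 + max(0, 29 - len(col))]
        col = col[0 : 29 + max(0, 29 - len(tab))]
        return "%s_%s_seq" % (tab, col)

    assert serial_seq_name("users", "id") == "users_id_seq"
    assert len(serial_seq_name("a" * 40, "b" * 40)) == 63
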
@@ -2164,7 +2314,7 @@ class PGExecutionContext(default.DefaultExecutionContext):
class PGDialect(default.DefaultDialect):
- name = 'postgresql'
+ name = "postgresql"
supports_alter = True
max_identifier_length = 63
supports_sane_rowcount = True
@@ -2182,7 +2332,7 @@ class PGDialect(default.DefaultDialect):
supports_default_values = True
supports_empty_insert = False
supports_multivalues_insert = True
- default_paramstyle = 'pyformat'
+ default_paramstyle = "pyformat"
ischema_names = ischema_names
colspecs = colspecs
@@ -2195,32 +2345,43 @@ class PGDialect(default.DefaultDialect):
isolation_level = None
construct_arguments = [
- (schema.Index, {
- "using": False,
- "where": None,
- "ops": {},
- "concurrently": False,
- "with": {},
- "tablespace": None
- }),
- (schema.Table, {
- "ignore_search_path": False,
- "tablespace": None,
- "partition_by": None,
- "with_oids": None,
- "on_commit": None,
- "inherits": None
- }),
+ (
+ schema.Index,
+ {
+ "using": False,
+ "where": None,
+ "ops": {},
+ "concurrently": False,
+ "with": {},
+ "tablespace": None,
+ },
+ ),
+ (
+ schema.Table,
+ {
+ "ignore_search_path": False,
+ "tablespace": None,
+ "partition_by": None,
+ "with_oids": None,
+ "on_commit": None,
+ "inherits": None,
+ },
+ ),
]
- reflection_options = ('postgresql_ignore_search_path', )
+ reflection_options = ("postgresql_ignore_search_path",)
_backslash_escapes = True
_supports_create_index_concurrently = True
_supports_drop_index_concurrently = True
- def __init__(self, isolation_level=None, json_serializer=None,
- json_deserializer=None, **kwargs):
+ def __init__(
+ self,
+ isolation_level=None,
+ json_serializer=None,
+ json_deserializer=None,
+ **kwargs
+ ):
default.DefaultDialect.__init__(self, **kwargs)
self.isolation_level = isolation_level
self._json_deserializer = json_deserializer
@@ -2228,8 +2389,10 @@ class PGDialect(default.DefaultDialect):
def initialize(self, connection):
super(PGDialect, self).initialize(connection)
- self.implicit_returning = self.server_version_info > (8, 2) and \
- self.__dict__.get('implicit_returning', True)
+ self.implicit_returning = self.server_version_info > (
+ 8,
+ 2,
+ ) and self.__dict__.get("implicit_returning", True)
self.supports_native_enum = self.server_version_info >= (8, 3)
if not self.supports_native_enum:
self.colspecs = self.colspecs.copy()
@@ -2241,45 +2404,57 @@ class PGDialect(default.DefaultDialect):
# http://www.postgresql.org/docs/9.3/static/release-9-2.html#AEN116689
self.supports_smallserial = self.server_version_info >= (9, 2)
- self._backslash_escapes = self.server_version_info < (8, 2) or \
- connection.scalar(
- "show standard_conforming_strings"
- ) == 'off'
+ self._backslash_escapes = (
+ self.server_version_info < (8, 2)
+ or connection.scalar("show standard_conforming_strings") == "off"
+ )
- self._supports_create_index_concurrently = \
+ self._supports_create_index_concurrently = (
self.server_version_info >= (8, 2)
- self._supports_drop_index_concurrently = \
- self.server_version_info >= (9, 2)
+ )
+ self._supports_drop_index_concurrently = self.server_version_info >= (
+ 9,
+ 2,
+ )
def on_connect(self):
if self.isolation_level is not None:
+
def connect(conn):
self.set_isolation_level(conn, self.isolation_level)
+
return connect
else:
return None
- _isolation_lookup = set(['SERIALIZABLE', 'READ UNCOMMITTED',
- 'READ COMMITTED', 'REPEATABLE READ'])
+ _isolation_lookup = set(
+ [
+ "SERIALIZABLE",
+ "READ UNCOMMITTED",
+ "READ COMMITTED",
+ "REPEATABLE READ",
+ ]
+ )
def set_isolation_level(self, connection, level):
- level = level.replace('_', ' ')
+ level = level.replace("_", " ")
if level not in self._isolation_lookup:
raise exc.ArgumentError(
"Invalid value '%s' for isolation_level. "
- "Valid isolation levels for %s are %s" %
- (level, self.name, ", ".join(self._isolation_lookup))
+ "Valid isolation levels for %s are %s"
+ % (level, self.name, ", ".join(self._isolation_lookup))
)
cursor = connection.cursor()
cursor.execute(
"SET SESSION CHARACTERISTICS AS TRANSACTION "
- "ISOLATION LEVEL %s" % level)
+ "ISOLATION LEVEL %s" % level
+ )
cursor.execute("COMMIT")
cursor.close()
def get_isolation_level(self, connection):
cursor = connection.cursor()
- cursor.execute('show transaction isolation level')
+ cursor.execute("show transaction isolation level")
val = cursor.fetchone()[0]
cursor.close()
return val.upper()
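
Note: these two methods back the isolation_level engine/connection argument; underscores are accepted because of the replace("_", " ") normalization:

    from sqlalchemy import create_engine

    # DSN is illustrative; "REPEATABLE_READ" == "REPEATABLE READ" here
    engine = create_engine(
        "postgresql://scott:tiger@localhost/test",
        isolation_level="REPEATABLE_READ",
    )
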
@@ -2290,8 +2465,9 @@ class PGDialect(default.DefaultDialect):
def do_prepare_twophase(self, connection, xid):
connection.execute("PREPARE TRANSACTION '%s'" % xid)
- def do_rollback_twophase(self, connection, xid,
- is_prepared=True, recover=False):
+ def do_rollback_twophase(
+ self, connection, xid, is_prepared=True, recover=False
+ ):
if is_prepared:
if recover:
# FIXME: ugly hack to get out of transaction
@@ -2305,8 +2481,9 @@ class PGDialect(default.DefaultDialect):
else:
self.do_rollback(connection.connection)
- def do_commit_twophase(self, connection, xid,
- is_prepared=True, recover=False):
+ def do_commit_twophase(
+ self, connection, xid, is_prepared=True, recover=False
+ ):
if is_prepared:
if recover:
connection.execute("ROLLBACK")
@@ -2318,22 +2495,27 @@ class PGDialect(default.DefaultDialect):
def do_recover_twophase(self, connection):
resultset = connection.execute(
- sql.text("SELECT gid FROM pg_prepared_xacts"))
+ sql.text("SELECT gid FROM pg_prepared_xacts")
+ )
return [row[0] for row in resultset]
def _get_default_schema_name(self, connection):
return connection.scalar("select current_schema()")
def has_schema(self, connection, schema):
- query = ("select nspname from pg_namespace "
- "where lower(nspname)=:schema")
+ query = (
+ "select nspname from pg_namespace " "where lower(nspname)=:schema"
+ )
cursor = connection.execute(
sql.text(
query,
bindparams=[
sql.bindparam(
- 'schema', util.text_type(schema.lower()),
- type_=sqltypes.Unicode)]
+ "schema",
+ util.text_type(schema.lower()),
+ type_=sqltypes.Unicode,
+ )
+ ],
)
)
@@ -2349,8 +2531,12 @@ class PGDialect(default.DefaultDialect):
"pg_catalog.pg_table_is_visible(c.oid) "
"and relname=:name",
bindparams=[
- sql.bindparam('name', util.text_type(table_name),
- type_=sqltypes.Unicode)]
+ sql.bindparam(
+ "name",
+ util.text_type(table_name),
+ type_=sqltypes.Unicode,
+ )
+ ],
)
)
else:
@@ -2360,12 +2546,17 @@ class PGDialect(default.DefaultDialect):
"n.oid=c.relnamespace where n.nspname=:schema and "
"relname=:name",
bindparams=[
- sql.bindparam('name',
- util.text_type(table_name),
- type_=sqltypes.Unicode),
- sql.bindparam('schema',
- util.text_type(schema),
- type_=sqltypes.Unicode)]
+ sql.bindparam(
+ "name",
+ util.text_type(table_name),
+ type_=sqltypes.Unicode,
+ ),
+ sql.bindparam(
+ "schema",
+ util.text_type(schema),
+ type_=sqltypes.Unicode,
+ ),
+ ],
)
)
return bool(cursor.first())
@@ -2379,9 +2570,12 @@ class PGDialect(default.DefaultDialect):
"n.nspname=current_schema() "
"and relname=:name",
bindparams=[
- sql.bindparam('name', util.text_type(sequence_name),
- type_=sqltypes.Unicode)
- ]
+ sql.bindparam(
+ "name",
+ util.text_type(sequence_name),
+ type_=sqltypes.Unicode,
+ )
+ ],
)
)
else:
@@ -2391,12 +2585,17 @@ class PGDialect(default.DefaultDialect):
"n.oid=c.relnamespace where relkind='S' and "
"n.nspname=:schema and relname=:name",
bindparams=[
- sql.bindparam('name', util.text_type(sequence_name),
- type_=sqltypes.Unicode),
- sql.bindparam('schema',
- util.text_type(schema),
- type_=sqltypes.Unicode)
- ]
+ sql.bindparam(
+ "name",
+ util.text_type(sequence_name),
+ type_=sqltypes.Unicode,
+ ),
+ sql.bindparam(
+ "schema",
+ util.text_type(schema),
+ type_=sqltypes.Unicode,
+ ),
+ ],
)
)
@@ -2423,13 +2622,15 @@ class PGDialect(default.DefaultDialect):
"""
query = sql.text(query)
query = query.bindparams(
- sql.bindparam('typname',
- util.text_type(type_name), type_=sqltypes.Unicode),
+ sql.bindparam(
+ "typname", util.text_type(type_name), type_=sqltypes.Unicode
+ )
)
if schema is not None:
query = query.bindparams(
- sql.bindparam('nspname',
- util.text_type(schema), type_=sqltypes.Unicode),
+ sql.bindparam(
+ "nspname", util.text_type(schema), type_=sqltypes.Unicode
+ )
)
cursor = connection.execute(query)
return bool(cursor.scalar())
@@ -2437,12 +2638,14 @@ class PGDialect(default.DefaultDialect):
def _get_server_version_info(self, connection):
v = connection.execute("select version()").scalar()
m = re.match(
- r'.*(?:PostgreSQL|EnterpriseDB) '
- r'(\d+)\.?(\d+)?(?:\.(\d+))?(?:\.\d+)?(?:devel|beta)?',
- v)
+ r".*(?:PostgreSQL|EnterpriseDB) "
+ r"(\d+)\.?(\d+)?(?:\.(\d+))?(?:\.\d+)?(?:devel|beta)?",
+ v,
+ )
if not m:
raise AssertionError(
- "Could not determine version from string '%s'" % v)
+ "Could not determine version from string '%s'" % v
+ )
return tuple([int(x) for x in m.group(1, 2, 3) if x is not None])
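
Note: a quick check of what the version regex extracts from a typical banner (the banner string is illustrative):

    import re

    VERSION_RE = re.compile(
        r".*(?:PostgreSQL|EnterpriseDB) "
        r"(\d+)\.?(\d+)?(?:\.(\d+))?(?:\.\d+)?(?:devel|beta)?"
    )
    m = VERSION_RE.match(
        "PostgreSQL 10.4 on x86_64-pc-linux-gnu, compiled by gcc"
    )
    assert tuple(int(x) for x in m.group(1, 2, 3) if x is not None) == (10, 4)
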
@reflection.cache
@@ -2459,14 +2662,17 @@ class PGDialect(default.DefaultDialect):
schema_where_clause = "n.nspname = :schema"
else:
schema_where_clause = "pg_catalog.pg_table_is_visible(c.oid)"
- query = """
+ query = (
+ """
SELECT c.oid
FROM pg_catalog.pg_class c
LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
WHERE (%s)
AND c.relname = :table_name AND c.relkind in
('r', 'v', 'm', 'f', 'p')
- """ % schema_where_clause
+ """
+ % schema_where_clause
+ )
# Since we're binding to unicode, table_name and schema_name must be
# unicode.
table_name = util.text_type(table_name)
@@ -2475,7 +2681,7 @@ class PGDialect(default.DefaultDialect):
s = sql.text(query).bindparams(table_name=sqltypes.Unicode)
s = s.columns(oid=sqltypes.Integer)
if schema:
- s = s.bindparams(sql.bindparam('schema', type_=sqltypes.Unicode))
+ s = s.bindparams(sql.bindparam("schema", type_=sqltypes.Unicode))
c = connection.execute(s, table_name=table_name, schema=schema)
table_oid = c.scalar()
if table_oid is None:
@@ -2485,75 +2691,88 @@ class PGDialect(default.DefaultDialect):
@reflection.cache
def get_schema_names(self, connection, **kw):
result = connection.execute(
- sql.text("SELECT nspname FROM pg_namespace "
- "WHERE nspname NOT LIKE 'pg_%' "
- "ORDER BY nspname"
- ).columns(nspname=sqltypes.Unicode))
+ sql.text(
+ "SELECT nspname FROM pg_namespace "
+ "WHERE nspname NOT LIKE 'pg_%' "
+ "ORDER BY nspname"
+ ).columns(nspname=sqltypes.Unicode)
+ )
return [name for name, in result]
@reflection.cache
def get_table_names(self, connection, schema=None, **kw):
result = connection.execute(
- sql.text("SELECT c.relname FROM pg_class c "
- "JOIN pg_namespace n ON n.oid = c.relnamespace "
- "WHERE n.nspname = :schema AND c.relkind in ('r', 'p')"
- ).columns(relname=sqltypes.Unicode),
- schema=schema if schema is not None else self.default_schema_name)
+ sql.text(
+ "SELECT c.relname FROM pg_class c "
+ "JOIN pg_namespace n ON n.oid = c.relnamespace "
+ "WHERE n.nspname = :schema AND c.relkind in ('r', 'p')"
+ ).columns(relname=sqltypes.Unicode),
+ schema=schema if schema is not None else self.default_schema_name,
+ )
return [name for name, in result]
@reflection.cache
def _get_foreign_table_names(self, connection, schema=None, **kw):
result = connection.execute(
- sql.text("SELECT c.relname FROM pg_class c "
- "JOIN pg_namespace n ON n.oid = c.relnamespace "
- "WHERE n.nspname = :schema AND c.relkind = 'f'"
- ).columns(relname=sqltypes.Unicode),
- schema=schema if schema is not None else self.default_schema_name)
+ sql.text(
+ "SELECT c.relname FROM pg_class c "
+ "JOIN pg_namespace n ON n.oid = c.relnamespace "
+ "WHERE n.nspname = :schema AND c.relkind = 'f'"
+ ).columns(relname=sqltypes.Unicode),
+ schema=schema if schema is not None else self.default_schema_name,
+ )
return [name for name, in result]
@reflection.cache
def get_view_names(
- self, connection, schema=None,
- include=('plain', 'materialized'), **kw):
+ self, connection, schema=None, include=("plain", "materialized"), **kw
+ ):
- include_kind = {'plain': 'v', 'materialized': 'm'}
+ include_kind = {"plain": "v", "materialized": "m"}
try:
kinds = [include_kind[i] for i in util.to_list(include)]
except KeyError:
raise ValueError(
"include %r unknown, needs to be a sequence containing "
- "one or both of 'plain' and 'materialized'" % (include,))
+ "one or both of 'plain' and 'materialized'" % (include,)
+ )
if not kinds:
raise ValueError(
"empty include, needs to be a sequence containing "
- "one or both of 'plain' and 'materialized'")
+ "one or both of 'plain' and 'materialized'"
+ )
result = connection.execute(
- sql.text("SELECT c.relname FROM pg_class c "
- "JOIN pg_namespace n ON n.oid = c.relnamespace "
- "WHERE n.nspname = :schema AND c.relkind IN (%s)" %
- (", ".join("'%s'" % elem for elem in kinds))
- ).columns(relname=sqltypes.Unicode),
- schema=schema if schema is not None else self.default_schema_name)
+ sql.text(
+ "SELECT c.relname FROM pg_class c "
+ "JOIN pg_namespace n ON n.oid = c.relnamespace "
+ "WHERE n.nspname = :schema AND c.relkind IN (%s)"
+ % (", ".join("'%s'" % elem for elem in kinds))
+ ).columns(relname=sqltypes.Unicode),
+ schema=schema if schema is not None else self.default_schema_name,
+ )
return [name for name, in result]
@reflection.cache
def get_view_definition(self, connection, view_name, schema=None, **kw):
view_def = connection.scalar(
- sql.text("SELECT pg_get_viewdef(c.oid) view_def FROM pg_class c "
- "JOIN pg_namespace n ON n.oid = c.relnamespace "
- "WHERE n.nspname = :schema AND c.relname = :view_name "
- "AND c.relkind IN ('v', 'm')"
- ).columns(view_def=sqltypes.Unicode),
+ sql.text(
+ "SELECT pg_get_viewdef(c.oid) view_def FROM pg_class c "
+ "JOIN pg_namespace n ON n.oid = c.relnamespace "
+ "WHERE n.nspname = :schema AND c.relname = :view_name "
+ "AND c.relkind IN ('v', 'm')"
+ ).columns(view_def=sqltypes.Unicode),
schema=schema if schema is not None else self.default_schema_name,
- view_name=view_name)
+ view_name=view_name,
+ )
return view_def
@reflection.cache
def get_columns(self, connection, table_name, schema=None, **kw):
- table_oid = self.get_table_oid(connection, table_name, schema,
- info_cache=kw.get('info_cache'))
+ table_oid = self.get_table_oid(
+ connection, table_name, schema, info_cache=kw.get("info_cache")
+ )
SQL_COLS = """
SELECT a.attname,
pg_catalog.format_type(a.atttypid, a.atttypmod),
@@ -2571,13 +2790,11 @@ class PGDialect(default.DefaultDialect):
AND a.attnum > 0 AND NOT a.attisdropped
ORDER BY a.attnum
"""
- s = sql.text(SQL_COLS,
- bindparams=[
- sql.bindparam('table_oid', type_=sqltypes.Integer)],
- typemap={
- 'attname': sqltypes.Unicode,
- 'default': sqltypes.Unicode}
- )
+ s = sql.text(
+ SQL_COLS,
+ bindparams=[sql.bindparam("table_oid", type_=sqltypes.Integer)],
+ typemap={"attname": sqltypes.Unicode, "default": sqltypes.Unicode},
+ )
c = connection.execute(s, table_oid=table_oid)
rows = c.fetchall()
@@ -2588,34 +2805,58 @@ class PGDialect(default.DefaultDialect):
# dictionary with (name, ) if default search path or (schema, name)
# as keys
enums = dict(
- ((rec['name'], ), rec)
- if rec['visible'] else ((rec['schema'], rec['name']), rec)
- for rec in self._load_enums(connection, schema='*')
+ ((rec["name"],), rec)
+ if rec["visible"]
+ else ((rec["schema"], rec["name"]), rec)
+ for rec in self._load_enums(connection, schema="*")
)
# format columns
columns = []
- for name, format_type, default_, notnull, attnum, table_oid, \
- comment in rows:
+ for (
+ name,
+ format_type,
+ default_,
+ notnull,
+ attnum,
+ table_oid,
+ comment,
+ ) in rows:
column_info = self._get_column_info(
- name, format_type, default_, notnull, domains, enums,
- schema, comment)
+ name,
+ format_type,
+ default_,
+ notnull,
+ domains,
+ enums,
+ schema,
+ comment,
+ )
columns.append(column_info)
return columns
- def _get_column_info(self, name, format_type, default,
- notnull, domains, enums, schema, comment):
+ def _get_column_info(
+ self,
+ name,
+ format_type,
+ default,
+ notnull,
+ domains,
+ enums,
+ schema,
+ comment,
+ ):
def _handle_array_type(attype):
return (
# strip '[]' from integer[], etc.
- re.sub(r'\[\]$', '', attype),
- attype.endswith('[]'),
+ re.sub(r"\[\]$", "", attype),
+ attype.endswith("[]"),
)
# strip (*) from character varying(5), timestamp(5)
# with time zone, geometry(POLYGON), etc.
- attype = re.sub(r'\(.*\)', '', format_type)
+ attype = re.sub(r"\(.*\)", "", format_type)
# strip '[]' from integer[], etc. and check if an array
attype, is_array = _handle_array_type(attype)
@@ -2625,50 +2866,52 @@ class PGDialect(default.DefaultDialect):
nullable = not notnull
- charlen = re.search(r'\(([\d,]+)\)', format_type)
+ charlen = re.search(r"\(([\d,]+)\)", format_type)
if charlen:
charlen = charlen.group(1)
- args = re.search(r'\((.*)\)', format_type)
+ args = re.search(r"\((.*)\)", format_type)
if args and args.group(1):
- args = tuple(re.split(r'\s*,\s*', args.group(1)))
+ args = tuple(re.split(r"\s*,\s*", args.group(1)))
else:
args = ()
kwargs = {}
- if attype == 'numeric':
+ if attype == "numeric":
if charlen:
- prec, scale = charlen.split(',')
+ prec, scale = charlen.split(",")
args = (int(prec), int(scale))
else:
args = ()
- elif attype == 'double precision':
- args = (53, )
- elif attype == 'integer':
+ elif attype == "double precision":
+ args = (53,)
+ elif attype == "integer":
args = ()
- elif attype in ('timestamp with time zone',
- 'time with time zone'):
- kwargs['timezone'] = True
+ elif attype in ("timestamp with time zone", "time with time zone"):
+ kwargs["timezone"] = True
if charlen:
- kwargs['precision'] = int(charlen)
+ kwargs["precision"] = int(charlen)
args = ()
- elif attype in ('timestamp without time zone',
- 'time without time zone', 'time'):
- kwargs['timezone'] = False
+ elif attype in (
+ "timestamp without time zone",
+ "time without time zone",
+ "time",
+ ):
+ kwargs["timezone"] = False
if charlen:
- kwargs['precision'] = int(charlen)
+ kwargs["precision"] = int(charlen)
args = ()
- elif attype == 'bit varying':
- kwargs['varying'] = True
+ elif attype == "bit varying":
+ kwargs["varying"] = True
if charlen:
args = (int(charlen),)
else:
args = ()
- elif attype.startswith('interval'):
- field_match = re.match(r'interval (.+)', attype, re.I)
+ elif attype.startswith("interval"):
+ field_match = re.match(r"interval (.+)", attype, re.I)
if charlen:
- kwargs['precision'] = int(charlen)
+ kwargs["precision"] = int(charlen)
if field_match:
- kwargs['fields'] = field_match.group(1)
+ kwargs["fields"] = field_match.group(1)
attype = "interval"
args = ()
elif charlen:
@@ -2682,23 +2925,23 @@ class PGDialect(default.DefaultDialect):
elif enum_or_domain_key in enums:
enum = enums[enum_or_domain_key]
coltype = ENUM
- kwargs['name'] = enum['name']
- if not enum['visible']:
- kwargs['schema'] = enum['schema']
- args = tuple(enum['labels'])
+ kwargs["name"] = enum["name"]
+ if not enum["visible"]:
+ kwargs["schema"] = enum["schema"]
+ args = tuple(enum["labels"])
break
elif enum_or_domain_key in domains:
domain = domains[enum_or_domain_key]
- attype = domain['attype']
+ attype = domain["attype"]
attype, is_array = _handle_array_type(attype)
# strip quotes from case sensitive enum or domain names
enum_or_domain_key = tuple(util.quoted_token_parser(attype))
# A table can't override whether the domain is nullable.
- nullable = domain['nullable']
- if domain['default'] and not default:
+ nullable = domain["nullable"]
+ if domain["default"] and not default:
# It can, however, override the default
# value, but can't set it to null.
- default = domain['default']
+ default = domain["default"]
continue
else:
coltype = None
@@ -2707,10 +2950,11 @@ class PGDialect(default.DefaultDialect):
if coltype:
coltype = coltype(*args, **kwargs)
if is_array:
- coltype = self.ischema_names['_array'](coltype)
+ coltype = self.ischema_names["_array"](coltype)
else:
- util.warn("Did not recognize type '%s' of column '%s'" %
- (attype, name))
+ util.warn(
+ "Did not recognize type '%s' of column '%s'" % (attype, name)
+ )
coltype = sqltypes.NULLTYPE
# adjust the default value
autoincrement = False
@@ -2721,23 +2965,33 @@ class PGDialect(default.DefaultDialect):
autoincrement = True
# the default is related to a Sequence
sch = schema
- if '.' not in match.group(2) and sch is not None:
+ if "." not in match.group(2) and sch is not None:
# unconditionally quote the schema name. this could
# later be enhanced to obey quoting rules /
# "quote schema"
- default = match.group(1) + \
- ('"%s"' % sch) + '.' + \
- match.group(2) + match.group(3)
+ default = (
+ match.group(1)
+ + ('"%s"' % sch)
+ + "."
+ + match.group(2)
+ + match.group(3)
+ )
- column_info = dict(name=name, type=coltype, nullable=nullable,
- default=default, autoincrement=autoincrement,
- comment=comment)
+ column_info = dict(
+ name=name,
+ type=coltype,
+ nullable=nullable,
+ default=default,
+ autoincrement=autoincrement,
+ comment=comment,
+ )
return column_info
@reflection.cache
def get_pk_constraint(self, connection, table_name, schema=None, **kw):
- table_oid = self.get_table_oid(connection, table_name, schema,
- info_cache=kw.get('info_cache'))
+ table_oid = self.get_table_oid(
+ connection, table_name, schema, info_cache=kw.get("info_cache")
+ )
if self.server_version_info < (8, 4):
PK_SQL = """
@@ -2750,7 +3004,9 @@ class PGDialect(default.DefaultDialect):
WHERE
t.oid = :table_oid and ix.indisprimary = 't'
ORDER BY a.attnum
- """ % self._pg_index_any("a.attnum", "ix.indkey")
+ """ % self._pg_index_any(
+ "a.attnum", "ix.indkey"
+ )
else:
# unnest() and generate_subscripts() both introduced in
@@ -2766,7 +3022,7 @@ class PGDialect(default.DefaultDialect):
WHERE a.attrelid = :table_oid
ORDER BY k.ord
"""
- t = sql.text(PK_SQL, typemap={'attname': sqltypes.Unicode})
+ t = sql.text(PK_SQL, typemap={"attname": sqltypes.Unicode})
c = connection.execute(t, table_oid=table_oid)
cols = [r[0] for r in c.fetchall()]
@@ -2776,18 +3032,25 @@ class PGDialect(default.DefaultDialect):
WHERE r.conrelid = :table_oid AND r.contype = 'p'
ORDER BY 1
"""
- t = sql.text(PK_CONS_SQL, typemap={'conname': sqltypes.Unicode})
+ t = sql.text(PK_CONS_SQL, typemap={"conname": sqltypes.Unicode})
c = connection.execute(t, table_oid=table_oid)
name = c.scalar()
- return {'constrained_columns': cols, 'name': name}
+ return {"constrained_columns": cols, "name": name}
@reflection.cache
- def get_foreign_keys(self, connection, table_name, schema=None,
- postgresql_ignore_search_path=False, **kw):
+ def get_foreign_keys(
+ self,
+ connection,
+ table_name,
+ schema=None,
+ postgresql_ignore_search_path=False,
+ **kw
+ ):
preparer = self.identifier_preparer
- table_oid = self.get_table_oid(connection, table_name, schema,
- info_cache=kw.get('info_cache'))
+ table_oid = self.get_table_oid(
+ connection, table_name, schema, info_cache=kw.get("info_cache")
+ )
FK_SQL = """
SELECT r.conname,
@@ -2805,34 +3068,35 @@ class PGDialect(default.DefaultDialect):
"""
# http://www.postgresql.org/docs/9.0/static/sql-createtable.html
FK_REGEX = re.compile(
- r'FOREIGN KEY \((.*?)\) REFERENCES (?:(.*?)\.)?(.*?)\((.*?)\)'
- r'[\s]?(MATCH (FULL|PARTIAL|SIMPLE)+)?'
- r'[\s]?(ON UPDATE '
- r'(CASCADE|RESTRICT|NO ACTION|SET NULL|SET DEFAULT)+)?'
- r'[\s]?(ON DELETE '
- r'(CASCADE|RESTRICT|NO ACTION|SET NULL|SET DEFAULT)+)?'
- r'[\s]?(DEFERRABLE|NOT DEFERRABLE)?'
- r'[\s]?(INITIALLY (DEFERRED|IMMEDIATE)+)?'
+ r"FOREIGN KEY \((.*?)\) REFERENCES (?:(.*?)\.)?(.*?)\((.*?)\)"
+ r"[\s]?(MATCH (FULL|PARTIAL|SIMPLE)+)?"
+ r"[\s]?(ON UPDATE "
+ r"(CASCADE|RESTRICT|NO ACTION|SET NULL|SET DEFAULT)+)?"
+ r"[\s]?(ON DELETE "
+ r"(CASCADE|RESTRICT|NO ACTION|SET NULL|SET DEFAULT)+)?"
+ r"[\s]?(DEFERRABLE|NOT DEFERRABLE)?"
+ r"[\s]?(INITIALLY (DEFERRED|IMMEDIATE)+)?"
)
- t = sql.text(FK_SQL, typemap={
- 'conname': sqltypes.Unicode,
- 'condef': sqltypes.Unicode})
+ t = sql.text(
+ FK_SQL,
+ typemap={"conname": sqltypes.Unicode, "condef": sqltypes.Unicode},
+ )
c = connection.execute(t, table=table_oid)
fkeys = []
for conname, condef, conschema in c.fetchall():
m = re.search(FK_REGEX, condef).groups()
- constrained_columns, referred_schema, \
- referred_table, referred_columns, \
- _, match, _, onupdate, _, ondelete, \
- deferrable, _, initially = m
+            (
+                constrained_columns, referred_schema, referred_table,
+                referred_columns, _, match, _, onupdate, _, ondelete,
+                deferrable, _, initially,
+            ) = m
if deferrable is not None:
- deferrable = True if deferrable == 'DEFERRABLE' else False
- constrained_columns = [preparer._unquote_identifier(x)
- for x in re.split(
- r'\s*,\s*', constrained_columns)]
+ deferrable = True if deferrable == "DEFERRABLE" else False
+ constrained_columns = [
+ preparer._unquote_identifier(x)
+ for x in re.split(r"\s*,\s*", constrained_columns)
+ ]
if postgresql_ignore_search_path:
# when ignoring search path, we use the actual schema
@@ -2845,30 +3109,30 @@ class PGDialect(default.DefaultDialect):
# referred_schema is the schema that we regexp'ed from
# pg_get_constraintdef(). If the schema is in the search
# path, pg_get_constraintdef() will give us None.
- referred_schema = \
- preparer._unquote_identifier(referred_schema)
+ referred_schema = preparer._unquote_identifier(referred_schema)
elif schema is not None and schema == conschema:
# If the actual schema matches the schema of the table
# we're reflecting, then we will use that.
referred_schema = schema
referred_table = preparer._unquote_identifier(referred_table)
- referred_columns = [preparer._unquote_identifier(x)
- for x in
- re.split(r'\s*,\s', referred_columns)]
+ referred_columns = [
+ preparer._unquote_identifier(x)
+ for x in re.split(r"\s*,\s", referred_columns)
+ ]
fkey_d = {
- 'name': conname,
- 'constrained_columns': constrained_columns,
- 'referred_schema': referred_schema,
- 'referred_table': referred_table,
- 'referred_columns': referred_columns,
- 'options': {
- 'onupdate': onupdate,
- 'ondelete': ondelete,
- 'deferrable': deferrable,
- 'initially': initially,
- 'match': match
- }
+ "name": conname,
+ "constrained_columns": constrained_columns,
+ "referred_schema": referred_schema,
+ "referred_table": referred_table,
+ "referred_columns": referred_columns,
+ "options": {
+ "onupdate": onupdate,
+ "ondelete": ondelete,
+ "deferrable": deferrable,
+ "initially": initially,
+ "match": match,
+ },
}
fkeys.append(fkey_d)
return fkeys
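
Note: for a pg_get_constraintdef() string like the following (names hypothetical), the values unpacked from FK_REGEX resolve as shown:

    condef = (
        "FOREIGN KEY (user_id) REFERENCES auth.users(id) "
        "ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED"
    )
    # constrained_columns -> "user_id"     referred_schema  -> "auth"
    # referred_table      -> "users"       referred_columns -> "id"
    # match -> None    onupdate -> None    ondelete -> "CASCADE"
    # deferrable -> "DEFERRABLE" (then coerced to True)
    # initially -> "DEFERRED"
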
@@ -2882,16 +3146,16 @@ class PGDialect(default.DefaultDialect):
# for now.
# regards, tom lane"
return "(%s)" % " OR ".join(
- "%s[%d] = %s" % (compare_to, ind, col)
- for ind in range(0, 10)
+ "%s[%d] = %s" % (compare_to, ind, col) for ind in range(0, 10)
)
else:
return "%s = ANY(%s)" % (col, compare_to)
@reflection.cache
def get_indexes(self, connection, table_name, schema, **kw):
- table_oid = self.get_table_oid(connection, table_name, schema,
- info_cache=kw.get('info_cache'))
+ table_oid = self.get_table_oid(
+ connection, table_name, schema, info_cache=kw.get("info_cache")
+ )
# cast indkey as varchar since it's an int2vector,
# returned as a list by some drivers such as pypostgresql
@@ -2925,9 +3189,10 @@ class PGDialect(default.DefaultDialect):
# cast does not work in PG 8.2.4, does work in 8.3.0.
# nothing in PG changelogs regarding this.
"::varchar" if self.server_version_info >= (8, 3) else "",
- "i.reloptions" if self.server_version_info >= (8, 2)
+ "i.reloptions"
+ if self.server_version_info >= (8, 2)
else "NULL",
- self._pg_index_any("a.attnum", "ix.indkey")
+ self._pg_index_any("a.attnum", "ix.indkey"),
)
else:
IDX_SQL = """
@@ -2960,76 +3225,93 @@ class PGDialect(default.DefaultDialect):
i.relname
"""
- t = sql.text(IDX_SQL, typemap={
- 'relname': sqltypes.Unicode,
- 'attname': sqltypes.Unicode})
+ t = sql.text(
+ IDX_SQL,
+ typemap={"relname": sqltypes.Unicode, "attname": sqltypes.Unicode},
+ )
c = connection.execute(t, table_oid=table_oid)
indexes = defaultdict(lambda: defaultdict(dict))
sv_idx_name = None
for row in c.fetchall():
- (idx_name, unique, expr, prd, col,
- col_num, conrelid, idx_key, options, amname) = row
+ (
+ idx_name,
+ unique,
+ expr,
+ prd,
+ col,
+ col_num,
+ conrelid,
+ idx_key,
+ options,
+ amname,
+ ) = row
if expr:
if idx_name != sv_idx_name:
util.warn(
"Skipped unsupported reflection of "
- "expression-based index %s"
- % idx_name)
+ "expression-based index %s" % idx_name
+ )
sv_idx_name = idx_name
continue
if prd and not idx_name == sv_idx_name:
util.warn(
"Predicate of partial index %s ignored during reflection"
- % idx_name)
+ % idx_name
+ )
sv_idx_name = idx_name
has_idx = idx_name in indexes
index = indexes[idx_name]
if col is not None:
- index['cols'][col_num] = col
+ index["cols"][col_num] = col
if not has_idx:
- index['key'] = [int(k.strip()) for k in idx_key.split()]
- index['unique'] = unique
+ index["key"] = [int(k.strip()) for k in idx_key.split()]
+ index["unique"] = unique
if conrelid is not None:
- index['duplicates_constraint'] = idx_name
+ index["duplicates_constraint"] = idx_name
if options:
- index['options'] = dict(
- [option.split("=") for option in options])
+ index["options"] = dict(
+ [option.split("=") for option in options]
+ )
# it *might* be nice to include that this is 'btree' in the
# reflection info. But we don't want an Index object
# to have a ``postgresql_using`` in it that is just the
# default, so for the moment leaving this out.
- if amname and amname != 'btree':
- index['amname'] = amname
+ if amname and amname != "btree":
+ index["amname"] = amname
result = []
for name, idx in indexes.items():
entry = {
- 'name': name,
- 'unique': idx['unique'],
- 'column_names': [idx['cols'][i] for i in idx['key']]
+ "name": name,
+ "unique": idx["unique"],
+ "column_names": [idx["cols"][i] for i in idx["key"]],
}
- if 'duplicates_constraint' in idx:
- entry['duplicates_constraint'] = idx['duplicates_constraint']
- if 'options' in idx:
- entry.setdefault(
- 'dialect_options', {})["postgresql_with"] = idx['options']
- if 'amname' in idx:
- entry.setdefault(
- 'dialect_options', {})["postgresql_using"] = idx['amname']
+ if "duplicates_constraint" in idx:
+ entry["duplicates_constraint"] = idx["duplicates_constraint"]
+ if "options" in idx:
+ entry.setdefault("dialect_options", {})[
+ "postgresql_with"
+ ] = idx["options"]
+ if "amname" in idx:
+ entry.setdefault("dialect_options", {})[
+ "postgresql_using"
+ ] = idx["amname"]
result.append(entry)
return result
@reflection.cache
- def get_unique_constraints(self, connection, table_name,
- schema=None, **kw):
- table_oid = self.get_table_oid(connection, table_name, schema,
- info_cache=kw.get('info_cache'))
+ def get_unique_constraints(
+ self, connection, table_name, schema=None, **kw
+ ):
+ table_oid = self.get_table_oid(
+ connection, table_name, schema, info_cache=kw.get("info_cache")
+ )
UNIQUE_SQL = """
SELECT
@@ -3047,7 +3329,7 @@ class PGDialect(default.DefaultDialect):
cons.contype = 'u'
"""
- t = sql.text(UNIQUE_SQL, typemap={'col_name': sqltypes.Unicode})
+ t = sql.text(UNIQUE_SQL, typemap={"col_name": sqltypes.Unicode})
c = connection.execute(t, table_oid=table_oid)
uniques = defaultdict(lambda: defaultdict(dict))
@@ -3057,15 +3339,15 @@ class PGDialect(default.DefaultDialect):
uc["cols"][row.col_num] = row.col_name
return [
- {'name': name,
- 'column_names': [uc["cols"][i] for i in uc["key"]]}
+ {"name": name, "column_names": [uc["cols"][i] for i in uc["key"]]}
for name, uc in uniques.items()
]
@reflection.cache
def get_table_comment(self, connection, table_name, schema=None, **kw):
- table_oid = self.get_table_oid(connection, table_name, schema,
- info_cache=kw.get('info_cache'))
+ table_oid = self.get_table_oid(
+ connection, table_name, schema, info_cache=kw.get("info_cache")
+ )
COMMENT_SQL = """
SELECT
@@ -3081,10 +3363,10 @@ class PGDialect(default.DefaultDialect):
return {"text": c.scalar()}
@reflection.cache
- def get_check_constraints(
- self, connection, table_name, schema=None, **kw):
- table_oid = self.get_table_oid(connection, table_name, schema,
- info_cache=kw.get('info_cache'))
+ def get_check_constraints(self, connection, table_name, schema=None, **kw):
+ table_oid = self.get_table_oid(
+ connection, table_name, schema, info_cache=kw.get("info_cache")
+ )
CHECK_SQL = """
SELECT
@@ -3100,10 +3382,8 @@ class PGDialect(default.DefaultDialect):
c = connection.execute(sql.text(CHECK_SQL), table_oid=table_oid)
return [
- {'name': name,
- 'sqltext': src[1:-1]}
- for name, src in c.fetchall()
- ]
+ {"name": name, "sqltext": src[1:-1]} for name, src in c.fetchall()
+ ]
def _load_enums(self, connection, schema=None):
schema = schema or self.default_schema_name
@@ -3124,17 +3404,18 @@ class PGDialect(default.DefaultDialect):
WHERE t.typtype = 'e'
"""
- if schema != '*':
+ if schema != "*":
SQL_ENUMS += "AND n.nspname = :schema "
# e.oid gives us label order within an enum
SQL_ENUMS += 'ORDER BY "schema", "name", e.oid'
- s = sql.text(SQL_ENUMS, typemap={
- 'attname': sqltypes.Unicode,
- 'label': sqltypes.Unicode})
+ s = sql.text(
+ SQL_ENUMS,
+ typemap={"attname": sqltypes.Unicode, "label": sqltypes.Unicode},
+ )
- if schema != '*':
+ if schema != "*":
s = s.bindparams(schema=schema)
c = connection.execute(s)
@@ -3142,15 +3423,15 @@ class PGDialect(default.DefaultDialect):
enums = []
enum_by_name = {}
for enum in c.fetchall():
- key = (enum['schema'], enum['name'])
+ key = (enum["schema"], enum["name"])
if key in enum_by_name:
- enum_by_name[key]['labels'].append(enum['label'])
+ enum_by_name[key]["labels"].append(enum["label"])
else:
enum_by_name[key] = enum_rec = {
- 'name': enum['name'],
- 'schema': enum['schema'],
- 'visible': enum['visible'],
- 'labels': [enum['label']],
+ "name": enum["name"],
+ "schema": enum["schema"],
+ "visible": enum["visible"],
+ "labels": [enum["label"]],
}
enums.append(enum_rec)
return enums
@@ -3169,26 +3450,26 @@ class PGDialect(default.DefaultDialect):
WHERE t.typtype = 'd'
"""
- s = sql.text(SQL_DOMAINS, typemap={'attname': sqltypes.Unicode})
+ s = sql.text(SQL_DOMAINS, typemap={"attname": sqltypes.Unicode})
c = connection.execute(s)
domains = {}
for domain in c.fetchall():
# strip (30) from character varying(30)
- attype = re.search(r'([^\(]+)', domain['attype']).group(1)
+ attype = re.search(r"([^\(]+)", domain["attype"]).group(1)
# 'visible' just means whether or not the domain is in a
# schema that's on the search path -- or not overridden by
# a schema with higher precedence. If it's not visible,
# it will be prefixed with the schema-name when it's used.
- if domain['visible']:
- key = (domain['name'], )
+ if domain["visible"]:
+ key = (domain["name"],)
else:
- key = (domain['schema'], domain['name'])
+ key = (domain["schema"], domain["name"])
domains[key] = {
- 'attype': attype,
- 'nullable': domain['nullable'],
- 'default': domain['default']
+ "attype": attype,
+ "nullable": domain["nullable"],
+ "default": domain["default"],
}
return domains
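
A one-line illustration of the attype normalization used in _load_domains
(a sketch; the type string is made up):

    import re

    # strip "(30)" from "character varying(30)", keeping only the base type
    attype = re.search(r"([^\(]+)", "character varying(30)").group(1)
    assert attype == "character varying"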
diff --git a/lib/sqlalchemy/dialects/postgresql/dml.py b/lib/sqlalchemy/dialects/postgresql/dml.py
index 555a9006c..825f13238 100644
--- a/lib/sqlalchemy/dialects/postgresql/dml.py
+++ b/lib/sqlalchemy/dialects/postgresql/dml.py
@@ -14,7 +14,7 @@ from ...sql.base import _generative
from ... import util
from . import ext
-__all__ = ('Insert', 'insert')
+__all__ = ("Insert", "insert")
class Insert(StandardInsert):
@@ -40,13 +40,17 @@ class Insert(StandardInsert):
to use :attr:`.Insert.excluded`
"""
- return alias(self.table, name='excluded').columns
+ return alias(self.table, name="excluded").columns
@_generative
def on_conflict_do_update(
- self,
- constraint=None, index_elements=None,
- index_where=None, set_=None, where=None):
+ self,
+ constraint=None,
+ index_elements=None,
+ index_where=None,
+ set_=None,
+ where=None,
+ ):
"""
        Specifies a DO UPDATE SET action for the ON CONFLICT clause.
@@ -96,13 +100,14 @@ class Insert(StandardInsert):
"""
self._post_values_clause = OnConflictDoUpdate(
- constraint, index_elements, index_where, set_, where)
+ constraint, index_elements, index_where, set_, where
+ )
return self
@_generative
def on_conflict_do_nothing(
- self,
- constraint=None, index_elements=None, index_where=None):
+ self, constraint=None, index_elements=None, index_where=None
+ ):
"""
        Specifies a DO NOTHING action for the ON CONFLICT clause.
@@ -130,30 +135,29 @@ class Insert(StandardInsert):
"""
self._post_values_clause = OnConflictDoNothing(
- constraint, index_elements, index_where)
+ constraint, index_elements, index_where
+ )
return self
-insert = public_factory(Insert, '.dialects.postgresql.insert')
+
+insert = public_factory(Insert, ".dialects.postgresql.insert")
class OnConflictClause(ClauseElement):
- def __init__(
- self,
- constraint=None,
- index_elements=None,
- index_where=None):
+ def __init__(self, constraint=None, index_elements=None, index_where=None):
if constraint is not None:
- if not isinstance(constraint, util.string_types) and \
- isinstance(constraint, (
- schema.Index, schema.Constraint,
- ext.ExcludeConstraint)):
- constraint = getattr(constraint, 'name') or constraint
+ if not isinstance(constraint, util.string_types) and isinstance(
+ constraint,
+ (schema.Index, schema.Constraint, ext.ExcludeConstraint),
+ ):
+ constraint = getattr(constraint, "name") or constraint
if constraint is not None:
if index_elements is not None:
raise ValueError(
- "'constraint' and 'index_elements' are mutually exclusive")
+ "'constraint' and 'index_elements' are mutually exclusive"
+ )
if isinstance(constraint, util.string_types):
self.constraint_target = constraint
@@ -161,54 +165,61 @@ class OnConflictClause(ClauseElement):
self.inferred_target_whereclause = None
elif isinstance(constraint, schema.Index):
index_elements = constraint.expressions
- index_where = \
- constraint.dialect_options['postgresql'].get("where")
+ index_where = constraint.dialect_options["postgresql"].get(
+ "where"
+ )
elif isinstance(constraint, ext.ExcludeConstraint):
index_elements = constraint.columns
index_where = constraint.where
else:
index_elements = constraint.columns
- index_where = \
- constraint.dialect_options['postgresql'].get("where")
+ index_where = constraint.dialect_options["postgresql"].get(
+ "where"
+ )
if index_elements is not None:
self.constraint_target = None
self.inferred_target_elements = index_elements
self.inferred_target_whereclause = index_where
elif constraint is None:
- self.constraint_target = self.inferred_target_elements = \
- self.inferred_target_whereclause = None
+ self.constraint_target = (
+ self.inferred_target_elements
+ ) = self.inferred_target_whereclause = None
class OnConflictDoNothing(OnConflictClause):
- __visit_name__ = 'on_conflict_do_nothing'
+ __visit_name__ = "on_conflict_do_nothing"
class OnConflictDoUpdate(OnConflictClause):
- __visit_name__ = 'on_conflict_do_update'
+ __visit_name__ = "on_conflict_do_update"
def __init__(
- self,
- constraint=None,
- index_elements=None,
- index_where=None,
- set_=None,
- where=None):
+ self,
+ constraint=None,
+ index_elements=None,
+ index_where=None,
+ set_=None,
+ where=None,
+ ):
super(OnConflictDoUpdate, self).__init__(
constraint=constraint,
index_elements=index_elements,
- index_where=index_where)
+ index_where=index_where,
+ )
- if self.inferred_target_elements is None and \
- self.constraint_target is None:
+ if (
+ self.inferred_target_elements is None
+ and self.constraint_target is None
+ ):
raise ValueError(
"Either constraint or index_elements, "
- "but not both, must be specified unless DO NOTHING")
+ "but not both, must be specified unless DO NOTHING"
+ )
- if (not isinstance(set_, dict) or not set_):
+ if not isinstance(set_, dict) or not set_:
raise ValueError("set parameter must be a non-empty dictionary")
self.update_values_to_set = [
- (key, value)
- for key, value in set_.items()
+ (key, value) for key, value in set_.items()
]
self.update_whereclause = where
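
For context, the public API reshaped above is used along these lines; a
minimal sketch in which the table definition is hypothetical:

    from sqlalchemy import Column, Integer, MetaData, String, Table
    from sqlalchemy.dialects.postgresql import insert

    metadata = MetaData()
    my_table = Table(
        "my_table",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("data", String),
    )

    stmt = insert(my_table).values(id=1, data="inserted value")

    # DO UPDATE requires a conflict target plus a non-empty set_ dict
    do_update = stmt.on_conflict_do_update(
        index_elements=["id"], set_=dict(data="updated value")
    )

    # DO NOTHING may omit the conflict target entirely
    do_nothing = stmt.on_conflict_do_nothing(index_elements=["id"])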
diff --git a/lib/sqlalchemy/dialects/postgresql/ext.py b/lib/sqlalchemy/dialects/postgresql/ext.py
index a588eafdd..da0c6250c 100644
--- a/lib/sqlalchemy/dialects/postgresql/ext.py
+++ b/lib/sqlalchemy/dialects/postgresql/ext.py
@@ -47,7 +47,7 @@ class aggregate_order_by(expression.ColumnElement):
"""
- __visit_name__ = 'aggregate_order_by'
+ __visit_name__ = "aggregate_order_by"
def __init__(self, target, *order_by):
self.target = elements._literal_as_binds(target)
@@ -59,8 +59,8 @@ class aggregate_order_by(expression.ColumnElement):
self.order_by = elements._literal_as_binds(order_by[0])
else:
self.order_by = elements.ClauseList(
- *order_by,
- _literal_as_text=elements._literal_as_binds)
+ *order_by, _literal_as_text=elements._literal_as_binds
+ )
def self_group(self, against=None):
return self
@@ -87,7 +87,7 @@ class ExcludeConstraint(ColumnCollectionConstraint):
static/sql-createtable.html#SQL-CREATETABLE-EXCLUDE
"""
- __visit_name__ = 'exclude_constraint'
+ __visit_name__ = "exclude_constraint"
where = None
@@ -173,8 +173,7 @@ static/sql-createtable.html#SQL-CREATETABLE-EXCLUDE
expressions, operators = zip(*elements)
for (expr, column, strname, add_element), operator in zip(
- self._extract_col_expression_collection(expressions),
- operators
+ self._extract_col_expression_collection(expressions), operators
):
if add_element is not None:
columns.append(add_element)
@@ -187,32 +186,31 @@ static/sql-createtable.html#SQL-CREATETABLE-EXCLUDE
expr = expression._literal_as_text(expr)
- render_exprs.append(
- (expr, name, operator)
- )
+ render_exprs.append((expr, name, operator))
self._render_exprs = render_exprs
ColumnCollectionConstraint.__init__(
self,
*columns,
- name=kw.get('name'),
- deferrable=kw.get('deferrable'),
- initially=kw.get('initially')
+ name=kw.get("name"),
+ deferrable=kw.get("deferrable"),
+ initially=kw.get("initially")
)
- self.using = kw.get('using', 'gist')
- where = kw.get('where')
+ self.using = kw.get("using", "gist")
+ where = kw.get("where")
if where is not None:
self.where = expression._literal_as_text(where)
def copy(self, **kw):
- elements = [(col, self.operators[col])
- for col in self.columns.keys()]
- c = self.__class__(*elements,
- name=self.name,
- deferrable=self.deferrable,
- initially=self.initially,
- where=self.where,
- using=self.using)
+ elements = [(col, self.operators[col]) for col in self.columns.keys()]
+ c = self.__class__(
+ *elements,
+ name=self.name,
+ deferrable=self.deferrable,
+ initially=self.initially,
+ where=self.where,
+ using=self.using
+ )
c.dispatch._update(self.dispatch)
return c
@@ -226,5 +224,5 @@ def array_agg(*arg, **kw):
.. versionadded:: 1.1
"""
- kw['_default_array_type'] = ARRAY
+ kw["_default_array_type"] = ARRAY
return functions.func.array_agg(*arg, **kw)
diff --git a/lib/sqlalchemy/dialects/postgresql/hstore.py b/lib/sqlalchemy/dialects/postgresql/hstore.py
index b6c9e7124..e4bac692a 100644
--- a/lib/sqlalchemy/dialects/postgresql/hstore.py
+++ b/lib/sqlalchemy/dialects/postgresql/hstore.py
@@ -14,38 +14,50 @@ from ...sql import functions as sqlfunc
from ...sql import operators
from ... import util
-__all__ = ('HSTORE', 'hstore')
+__all__ = ("HSTORE", "hstore")
idx_precedence = operators._PRECEDENCE[operators.json_getitem_op]
GETITEM = operators.custom_op(
- "->", precedence=idx_precedence, natural_self_precedent=True,
- eager_grouping=True
+ "->",
+ precedence=idx_precedence,
+ natural_self_precedent=True,
+ eager_grouping=True,
)
HAS_KEY = operators.custom_op(
- "?", precedence=idx_precedence, natural_self_precedent=True,
- eager_grouping=True
+ "?",
+ precedence=idx_precedence,
+ natural_self_precedent=True,
+ eager_grouping=True,
)
HAS_ALL = operators.custom_op(
- "?&", precedence=idx_precedence, natural_self_precedent=True,
- eager_grouping=True
+ "?&",
+ precedence=idx_precedence,
+ natural_self_precedent=True,
+ eager_grouping=True,
)
HAS_ANY = operators.custom_op(
- "?|", precedence=idx_precedence, natural_self_precedent=True,
- eager_grouping=True
+ "?|",
+ precedence=idx_precedence,
+ natural_self_precedent=True,
+ eager_grouping=True,
)
CONTAINS = operators.custom_op(
- "@>", precedence=idx_precedence, natural_self_precedent=True,
- eager_grouping=True
+ "@>",
+ precedence=idx_precedence,
+ natural_self_precedent=True,
+ eager_grouping=True,
)
CONTAINED_BY = operators.custom_op(
- "<@", precedence=idx_precedence, natural_self_precedent=True,
- eager_grouping=True
+ "<@",
+ precedence=idx_precedence,
+ natural_self_precedent=True,
+ eager_grouping=True,
)
@@ -122,7 +134,7 @@ class HSTORE(sqltypes.Indexable, sqltypes.Concatenable, sqltypes.TypeEngine):
"""
- __visit_name__ = 'HSTORE'
+ __visit_name__ = "HSTORE"
hashable = False
text_type = sqltypes.Text()
@@ -139,7 +151,8 @@ class HSTORE(sqltypes.Indexable, sqltypes.Concatenable, sqltypes.TypeEngine):
self.text_type = text_type
class Comparator(
- sqltypes.Indexable.Comparator, sqltypes.Concatenable.Comparator):
+ sqltypes.Indexable.Comparator, sqltypes.Concatenable.Comparator
+ ):
"""Define comparison operations for :class:`.HSTORE`."""
def has_key(self, other):
@@ -169,7 +182,8 @@ class HSTORE(sqltypes.Indexable, sqltypes.Concatenable, sqltypes.TypeEngine):
keys of the argument jsonb expression.
"""
return self.operate(
- CONTAINED_BY, other, result_type=sqltypes.Boolean)
+ CONTAINED_BY, other, result_type=sqltypes.Boolean
+ )
def _setup_getitem(self, index):
return GETITEM, index, self.type.text_type
@@ -223,12 +237,15 @@ class HSTORE(sqltypes.Indexable, sqltypes.Concatenable, sqltypes.TypeEngine):
return _serialize_hstore(value).encode(encoding)
else:
return value
+
else:
+
def process(value):
if isinstance(value, dict):
return _serialize_hstore(value)
else:
return value
+
return process
def result_processor(self, dialect, coltype):
@@ -240,16 +257,19 @@ class HSTORE(sqltypes.Indexable, sqltypes.Concatenable, sqltypes.TypeEngine):
return _parse_hstore(value.decode(encoding))
else:
return value
+
else:
+
def process(value):
if value is not None:
return _parse_hstore(value)
else:
return value
+
return process
-ischema_names['hstore'] = HSTORE
+ischema_names["hstore"] = HSTORE
class hstore(sqlfunc.GenericFunction):
@@ -279,43 +299,44 @@ class hstore(sqlfunc.GenericFunction):
:class:`.HSTORE` - the PostgreSQL ``HSTORE`` datatype.
"""
+
type = HSTORE
- name = 'hstore'
+ name = "hstore"
class _HStoreDefinedFunction(sqlfunc.GenericFunction):
type = sqltypes.Boolean
- name = 'defined'
+ name = "defined"
class _HStoreDeleteFunction(sqlfunc.GenericFunction):
type = HSTORE
- name = 'delete'
+ name = "delete"
class _HStoreSliceFunction(sqlfunc.GenericFunction):
type = HSTORE
- name = 'slice'
+ name = "slice"
class _HStoreKeysFunction(sqlfunc.GenericFunction):
type = ARRAY(sqltypes.Text)
- name = 'akeys'
+ name = "akeys"
class _HStoreValsFunction(sqlfunc.GenericFunction):
type = ARRAY(sqltypes.Text)
- name = 'avals'
+ name = "avals"
class _HStoreArrayFunction(sqlfunc.GenericFunction):
type = ARRAY(sqltypes.Text)
- name = 'hstore_to_array'
+ name = "hstore_to_array"
class _HStoreMatrixFunction(sqlfunc.GenericFunction):
type = ARRAY(sqltypes.Text)
- name = 'hstore_to_matrix'
+ name = "hstore_to_matrix"
#
@@ -326,7 +347,8 @@ class _HStoreMatrixFunction(sqlfunc.GenericFunction):
# My best guess at the parsing rules of hstore literals, since no formal
# grammar is given. This is mostly reverse engineered from PG's input parser
# behavior.
-HSTORE_PAIR_RE = re.compile(r"""
+HSTORE_PAIR_RE = re.compile(
+ r"""
(
"(?P<key> (\\ . | [^"])* )" # Quoted key
)
@@ -335,11 +357,16 @@ HSTORE_PAIR_RE = re.compile(r"""
(?P<value_null> NULL ) # NULL value
| "(?P<value> (\\ . | [^"])* )" # Quoted value
)
-""", re.VERBOSE)
+""",
+ re.VERBOSE,
+)
-HSTORE_DELIMITER_RE = re.compile(r"""
+HSTORE_DELIMITER_RE = re.compile(
+ r"""
[ ]* , [ ]*
-""", re.VERBOSE)
+""",
+ re.VERBOSE,
+)
def _parse_error(hstore_str, pos):
@@ -348,16 +375,19 @@ def _parse_error(hstore_str, pos):
ctx = 20
hslen = len(hstore_str)
- parsed_tail = hstore_str[max(pos - ctx - 1, 0):min(pos, hslen)]
- residual = hstore_str[min(pos, hslen):min(pos + ctx + 1, hslen)]
+ parsed_tail = hstore_str[max(pos - ctx - 1, 0) : min(pos, hslen)]
+ residual = hstore_str[min(pos, hslen) : min(pos + ctx + 1, hslen)]
if len(parsed_tail) > ctx:
- parsed_tail = '[...]' + parsed_tail[1:]
+ parsed_tail = "[...]" + parsed_tail[1:]
if len(residual) > ctx:
- residual = residual[:-1] + '[...]'
+ residual = residual[:-1] + "[...]"
return "After %r, could not parse residual at position %d: %r" % (
- parsed_tail, pos, residual)
+ parsed_tail,
+ pos,
+ residual,
+ )
def _parse_hstore(hstore_str):
@@ -377,13 +407,15 @@ def _parse_hstore(hstore_str):
pair_match = HSTORE_PAIR_RE.match(hstore_str)
while pair_match is not None:
- key = pair_match.group('key').replace(r'\"', '"').replace(
- "\\\\", "\\")
- if pair_match.group('value_null'):
+ key = pair_match.group("key").replace(r"\"", '"').replace("\\\\", "\\")
+ if pair_match.group("value_null"):
value = None
else:
- value = pair_match.group('value').replace(
- r'\"', '"').replace("\\\\", "\\")
+ value = (
+ pair_match.group("value")
+ .replace(r"\"", '"')
+ .replace("\\\\", "\\")
+ )
result[key] = value
pos += pair_match.end()
@@ -405,16 +437,17 @@ def _serialize_hstore(val):
both be strings (except None for values).
"""
+
def esc(s, position):
- if position == 'value' and s is None:
- return 'NULL'
+ if position == "value" and s is None:
+ return "NULL"
elif isinstance(s, util.string_types):
- return '"%s"' % s.replace("\\", "\\\\").replace('"', r'\"')
+ return '"%s"' % s.replace("\\", "\\\\").replace('"', r"\"")
else:
- raise ValueError("%r in %s position is not a string." %
- (s, position))
-
- return ', '.join('%s=>%s' % (esc(k, 'key'), esc(v, 'value'))
- for k, v in val.items())
-
+ raise ValueError(
+ "%r in %s position is not a string." % (s, position)
+ )
+ return ", ".join(
+ "%s=>%s" % (esc(k, "key"), esc(v, "value")) for k, v in val.items()
+ )
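
A round-trip sketch of the two helpers reformatted above (they are private
to the dialect, so the import path is an internal detail):

    from sqlalchemy.dialects.postgresql.hstore import (
        _parse_hstore,
        _serialize_hstore,
    )

    serialized = _serialize_hstore({"key1": "value1", "key2": None})
    # '"key1"=>"value1", "key2"=>NULL'
    assert _parse_hstore(serialized) == {"key1": "value1", "key2": None}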
diff --git a/lib/sqlalchemy/dialects/postgresql/json.py b/lib/sqlalchemy/dialects/postgresql/json.py
index e9256daf3..f9421de37 100644
--- a/lib/sqlalchemy/dialects/postgresql/json.py
+++ b/lib/sqlalchemy/dialects/postgresql/json.py
@@ -12,44 +12,58 @@ from ...sql import operators
from ...sql import elements
from ... import util
-__all__ = ('JSON', 'JSONB')
+__all__ = ("JSON", "JSONB")
idx_precedence = operators._PRECEDENCE[operators.json_getitem_op]
ASTEXT = operators.custom_op(
- "->>", precedence=idx_precedence, natural_self_precedent=True,
- eager_grouping=True
+ "->>",
+ precedence=idx_precedence,
+ natural_self_precedent=True,
+ eager_grouping=True,
)
JSONPATH_ASTEXT = operators.custom_op(
- "#>>", precedence=idx_precedence, natural_self_precedent=True,
- eager_grouping=True
+ "#>>",
+ precedence=idx_precedence,
+ natural_self_precedent=True,
+ eager_grouping=True,
)
HAS_KEY = operators.custom_op(
- "?", precedence=idx_precedence, natural_self_precedent=True,
- eager_grouping=True
+ "?",
+ precedence=idx_precedence,
+ natural_self_precedent=True,
+ eager_grouping=True,
)
HAS_ALL = operators.custom_op(
- "?&", precedence=idx_precedence, natural_self_precedent=True,
- eager_grouping=True
+ "?&",
+ precedence=idx_precedence,
+ natural_self_precedent=True,
+ eager_grouping=True,
)
HAS_ANY = operators.custom_op(
- "?|", precedence=idx_precedence, natural_self_precedent=True,
- eager_grouping=True
+ "?|",
+ precedence=idx_precedence,
+ natural_self_precedent=True,
+ eager_grouping=True,
)
CONTAINS = operators.custom_op(
- "@>", precedence=idx_precedence, natural_self_precedent=True,
- eager_grouping=True
+ "@>",
+ precedence=idx_precedence,
+ natural_self_precedent=True,
+ eager_grouping=True,
)
CONTAINED_BY = operators.custom_op(
- "<@", precedence=idx_precedence, natural_self_precedent=True,
- eager_grouping=True
+ "<@",
+ precedence=idx_precedence,
+ natural_self_precedent=True,
+ eager_grouping=True,
)
@@ -59,7 +73,7 @@ class JSONPathType(sqltypes.JSON.JSONPathType):
def process(value):
assert isinstance(value, util.collections_abc.Sequence)
- tokens = [util.text_type(elem)for elem in value]
+ tokens = [util.text_type(elem) for elem in value]
value = "{%s}" % (", ".join(tokens))
if super_proc:
value = super_proc(value)
@@ -72,7 +86,7 @@ class JSONPathType(sqltypes.JSON.JSONPathType):
def process(value):
assert isinstance(value, util.collections_abc.Sequence)
- tokens = [util.text_type(elem)for elem in value]
+ tokens = [util.text_type(elem) for elem in value]
value = "{%s}" % (", ".join(tokens))
if super_proc:
value = super_proc(value)
@@ -80,6 +94,7 @@ class JSONPathType(sqltypes.JSON.JSONPathType):
return process
+
colspecs[sqltypes.JSON.JSONPathType] = JSONPathType
@@ -203,16 +218,19 @@ class JSON(sqltypes.JSON):
if isinstance(self.expr.right.type, sqltypes.JSON.JSONPathType):
return self.expr.left.operate(
JSONPATH_ASTEXT,
- self.expr.right, result_type=self.type.astext_type)
+ self.expr.right,
+ result_type=self.type.astext_type,
+ )
else:
return self.expr.left.operate(
- ASTEXT, self.expr.right, result_type=self.type.astext_type)
+ ASTEXT, self.expr.right, result_type=self.type.astext_type
+ )
comparator_factory = Comparator
colspecs[sqltypes.JSON] = JSON
-ischema_names['json'] = JSON
+ischema_names["json"] = JSON
class JSONB(JSON):
@@ -259,7 +277,7 @@ class JSONB(JSON):
"""
- __visit_name__ = 'JSONB'
+ __visit_name__ = "JSONB"
class Comparator(JSON.Comparator):
"""Define comparison operations for :class:`.JSON`."""
@@ -291,8 +309,10 @@ class JSONB(JSON):
keys of the argument jsonb expression.
"""
return self.operate(
- CONTAINED_BY, other, result_type=sqltypes.Boolean)
+ CONTAINED_BY, other, result_type=sqltypes.Boolean
+ )
comparator_factory = Comparator
-ischema_names['jsonb'] = JSONB
+
+ischema_names["jsonb"] = JSONB
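
The operators defined above surface through the type's comparator; a brief
sketch against a hypothetical table:

    from sqlalchemy import Column, Integer, MetaData, Table
    from sqlalchemy.dialects.postgresql import JSONB

    metadata = MetaData()
    docs = Table(
        "docs", metadata, Column("id", Integer), Column("payload", JSONB)
    )

    astext = docs.c.payload["name"].astext        # payload ->> 'name'
    contains = docs.c.payload.contains({"a": 1})  # payload @> '{"a": 1}'
    has_key = docs.c.payload.has_key("name")      # payload ? 'name'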
diff --git a/lib/sqlalchemy/dialects/postgresql/pg8000.py b/lib/sqlalchemy/dialects/postgresql/pg8000.py
index 80929b808..fef09e0eb 100644
--- a/lib/sqlalchemy/dialects/postgresql/pg8000.py
+++ b/lib/sqlalchemy/dialects/postgresql/pg8000.py
@@ -69,8 +69,15 @@ import decimal
from ... import processors
from ... import types as sqltypes
from .base import (
- PGDialect, PGCompiler, PGIdentifierPreparer, PGExecutionContext,
- _DECIMAL_TYPES, _FLOAT_TYPES, _INT_TYPES, UUID)
+ PGDialect,
+ PGCompiler,
+ PGIdentifierPreparer,
+ PGExecutionContext,
+ _DECIMAL_TYPES,
+ _FLOAT_TYPES,
+ _INT_TYPES,
+ UUID,
+)
import re
from sqlalchemy.dialects.postgresql.json import JSON
from ...sql.elements import quoted_name
@@ -86,13 +93,15 @@ class _PGNumeric(sqltypes.Numeric):
if self.asdecimal:
if coltype in _FLOAT_TYPES:
return processors.to_decimal_processor_factory(
- decimal.Decimal, self._effective_decimal_return_scale)
+ decimal.Decimal, self._effective_decimal_return_scale
+ )
elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES:
# pg8000 returns Decimal natively for 1700
return None
else:
raise exc.InvalidRequestError(
- "Unknown PG numeric type: %d" % coltype)
+ "Unknown PG numeric type: %d" % coltype
+ )
else:
if coltype in _FLOAT_TYPES:
# pg8000 returns float natively for 701
@@ -101,7 +110,8 @@ class _PGNumeric(sqltypes.Numeric):
return processors.to_float
else:
raise exc.InvalidRequestError(
- "Unknown PG numeric type: %d" % coltype)
+ "Unknown PG numeric type: %d" % coltype
+ )
class _PGNumericNoBind(_PGNumeric):
@@ -110,7 +120,6 @@ class _PGNumericNoBind(_PGNumeric):
class _PGJSON(JSON):
-
def result_processor(self, dialect, coltype):
if dialect._dbapi_version > (1, 10, 1):
return None # Has native JSON
@@ -121,18 +130,22 @@ class _PGJSON(JSON):
class _PGUUID(UUID):
def bind_processor(self, dialect):
if not self.as_uuid:
+
def process(value):
if value is not None:
value = _python_UUID(value)
return value
+
return process
def result_processor(self, dialect, coltype):
if not self.as_uuid:
+
def process(value):
if value is not None:
value = str(value)
return value
+
return process
@@ -142,36 +155,41 @@ class PGExecutionContext_pg8000(PGExecutionContext):
class PGCompiler_pg8000(PGCompiler):
def visit_mod_binary(self, binary, operator, **kw):
- return self.process(binary.left, **kw) + " %% " + \
- self.process(binary.right, **kw)
+ return (
+ self.process(binary.left, **kw)
+ + " %% "
+ + self.process(binary.right, **kw)
+ )
def post_process_text(self, text):
- if '%%' in text:
- util.warn("The SQLAlchemy postgresql dialect "
- "now automatically escapes '%' in text() "
- "expressions to '%%'.")
- return text.replace('%', '%%')
+ if "%%" in text:
+ util.warn(
+ "The SQLAlchemy postgresql dialect "
+ "now automatically escapes '%' in text() "
+ "expressions to '%%'."
+ )
+ return text.replace("%", "%%")
class PGIdentifierPreparer_pg8000(PGIdentifierPreparer):
def _escape_identifier(self, value):
value = value.replace(self.escape_quote, self.escape_to_quote)
- return value.replace('%', '%%')
+ return value.replace("%", "%%")
class PGDialect_pg8000(PGDialect):
- driver = 'pg8000'
+ driver = "pg8000"
supports_unicode_statements = True
supports_unicode_binds = True
- default_paramstyle = 'format'
+ default_paramstyle = "format"
supports_sane_multi_rowcount = True
execution_ctx_cls = PGExecutionContext_pg8000
statement_compiler = PGCompiler_pg8000
preparer = PGIdentifierPreparer_pg8000
- description_encoding = 'use_encoding'
+ description_encoding = "use_encoding"
colspecs = util.update_copy(
PGDialect.colspecs,
@@ -180,8 +198,8 @@ class PGDialect_pg8000(PGDialect):
sqltypes.Float: _PGNumeric,
JSON: _PGJSON,
sqltypes.JSON: _PGJSON,
- UUID: _PGUUID
- }
+ UUID: _PGUUID,
+ },
)
def __init__(self, client_encoding=None, **kwargs):
@@ -194,22 +212,26 @@ class PGDialect_pg8000(PGDialect):
@util.memoized_property
def _dbapi_version(self):
- if self.dbapi and hasattr(self.dbapi, '__version__'):
+ if self.dbapi and hasattr(self.dbapi, "__version__"):
return tuple(
[
- int(x) for x in re.findall(
- r'(\d+)(?:[-\.]?|$)', self.dbapi.__version__)])
+ int(x)
+ for x in re.findall(
+ r"(\d+)(?:[-\.]?|$)", self.dbapi.__version__
+ )
+ ]
+ )
else:
return (99, 99, 99)
@classmethod
def dbapi(cls):
- return __import__('pg8000')
+ return __import__("pg8000")
def create_connect_args(self, url):
- opts = url.translate_connect_args(username='user')
- if 'port' in opts:
- opts['port'] = int(opts['port'])
+ opts = url.translate_connect_args(username="user")
+ if "port" in opts:
+ opts["port"] = int(opts["port"])
opts.update(url.query)
return ([], opts)
@@ -217,32 +239,33 @@ class PGDialect_pg8000(PGDialect):
return "connection is closed" in str(e)
def set_isolation_level(self, connection, level):
- level = level.replace('_', ' ')
+ level = level.replace("_", " ")
# adjust for ConnectionFairy possibly being present
- if hasattr(connection, 'connection'):
+ if hasattr(connection, "connection"):
connection = connection.connection
- if level == 'AUTOCOMMIT':
+ if level == "AUTOCOMMIT":
connection.autocommit = True
elif level in self._isolation_lookup:
connection.autocommit = False
cursor = connection.cursor()
cursor.execute(
"SET SESSION CHARACTERISTICS AS TRANSACTION "
- "ISOLATION LEVEL %s" % level)
+ "ISOLATION LEVEL %s" % level
+ )
cursor.execute("COMMIT")
cursor.close()
else:
raise exc.ArgumentError(
"Invalid value '%s' for isolation_level. "
- "Valid isolation levels for %s are %s or AUTOCOMMIT" %
- (level, self.name, ", ".join(self._isolation_lookup))
+ "Valid isolation levels for %s are %s or AUTOCOMMIT"
+ % (level, self.name, ", ".join(self._isolation_lookup))
)
def set_client_encoding(self, connection, client_encoding):
# adjust for ConnectionFairy possibly being present
- if hasattr(connection, 'connection'):
+ if hasattr(connection, "connection"):
connection = connection.connection
cursor = connection.cursor()
@@ -251,18 +274,20 @@ class PGDialect_pg8000(PGDialect):
cursor.close()
def do_begin_twophase(self, connection, xid):
- connection.connection.tpc_begin((0, xid, ''))
+ connection.connection.tpc_begin((0, xid, ""))
def do_prepare_twophase(self, connection, xid):
connection.connection.tpc_prepare()
def do_rollback_twophase(
- self, connection, xid, is_prepared=True, recover=False):
- connection.connection.tpc_rollback((0, xid, ''))
+ self, connection, xid, is_prepared=True, recover=False
+ ):
+ connection.connection.tpc_rollback((0, xid, ""))
def do_commit_twophase(
- self, connection, xid, is_prepared=True, recover=False):
- connection.connection.tpc_commit((0, xid, ''))
+ self, connection, xid, is_prepared=True, recover=False
+ ):
+ connection.connection.tpc_commit((0, xid, ""))
def do_recover_twophase(self, connection):
return [row[1] for row in connection.connection.tpc_recover()]
@@ -272,24 +297,32 @@ class PGDialect_pg8000(PGDialect):
def on_connect(conn):
conn.py_types[quoted_name] = conn.py_types[util.text_type]
+
fns.append(on_connect)
if self.client_encoding is not None:
+
def on_connect(conn):
self.set_client_encoding(conn, self.client_encoding)
+
fns.append(on_connect)
if self.isolation_level is not None:
+
def on_connect(conn):
self.set_isolation_level(conn, self.isolation_level)
+
fns.append(on_connect)
if len(fns) > 0:
+
def on_connect(conn):
for fn in fns:
fn(conn)
+
return on_connect
else:
return None
+
dialect = PGDialect_pg8000
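
The _dbapi_version parse above reduces to the following (the version string
is illustrative):

    import re

    version = tuple(
        int(x) for x in re.findall(r"(\d+)(?:[-\.]?|$)", "1.16.6")
    )
    assert version == (1, 16, 6)
    # when the DBAPI exposes no __version__ at all, the dialect
    # assumes (99, 99, 99)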
diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2.py b/lib/sqlalchemy/dialects/postgresql/psycopg2.py
index baa0e00d5..2c27c6919 100644
--- a/lib/sqlalchemy/dialects/postgresql/psycopg2.py
+++ b/lib/sqlalchemy/dialects/postgresql/psycopg2.py
@@ -353,10 +353,17 @@ from ... import processors
from ...engine import result as _result
from ...sql import expression
from ... import types as sqltypes
-from .base import PGDialect, PGCompiler, \
- PGIdentifierPreparer, PGExecutionContext, \
- ENUM, _DECIMAL_TYPES, _FLOAT_TYPES,\
- _INT_TYPES, UUID
+from .base import (
+ PGDialect,
+ PGCompiler,
+ PGIdentifierPreparer,
+ PGExecutionContext,
+ ENUM,
+ _DECIMAL_TYPES,
+ _FLOAT_TYPES,
+ _INT_TYPES,
+ UUID,
+)
from .hstore import HSTORE
from .json import JSON, JSONB
@@ -366,7 +373,7 @@ except ImportError:
_python_UUID = None
-logger = logging.getLogger('sqlalchemy.dialects.postgresql')
+logger = logging.getLogger("sqlalchemy.dialects.postgresql")
class _PGNumeric(sqltypes.Numeric):
@@ -377,14 +384,15 @@ class _PGNumeric(sqltypes.Numeric):
if self.asdecimal:
if coltype in _FLOAT_TYPES:
return processors.to_decimal_processor_factory(
- decimal.Decimal,
- self._effective_decimal_return_scale)
+ decimal.Decimal, self._effective_decimal_return_scale
+ )
elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES:
# pg8000 returns Decimal natively for 1700
return None
else:
raise exc.InvalidRequestError(
- "Unknown PG numeric type: %d" % coltype)
+ "Unknown PG numeric type: %d" % coltype
+ )
else:
if coltype in _FLOAT_TYPES:
# pg8000 returns float natively for 701
@@ -393,7 +401,8 @@ class _PGNumeric(sqltypes.Numeric):
return processors.to_float
else:
raise exc.InvalidRequestError(
- "Unknown PG numeric type: %d" % coltype)
+ "Unknown PG numeric type: %d" % coltype
+ )
class _PGEnum(ENUM):
@@ -421,7 +430,6 @@ class _PGHStore(HSTORE):
class _PGJSON(JSON):
-
def result_processor(self, dialect, coltype):
if dialect._has_native_json:
return None
@@ -430,7 +438,6 @@ class _PGJSON(JSON):
class _PGJSONB(JSONB):
-
def result_processor(self, dialect, coltype):
if dialect._has_native_jsonb:
return None
@@ -447,14 +454,17 @@ class _PGUUID(UUID):
if value is not None:
value = _python_UUID(value)
return value
+
return process
def result_processor(self, dialect, coltype):
if not self.as_uuid and dialect.use_native_uuid:
+
def process(value):
if value is not None:
value = str(value)
return value
+
return process
@@ -465,8 +475,7 @@ class PGExecutionContext_psycopg2(PGExecutionContext):
def create_server_side_cursor(self):
# use server-side cursors:
# http://lists.initd.org/pipermail/psycopg/2007-January/005251.html
- ident = "c_%s_%s" % (hex(id(self))[2:],
- hex(_server_side_id())[2:])
+ ident = "c_%s_%s" % (hex(id(self))[2:], hex(_server_side_id())[2:])
return self._dbapi_connection.cursor(ident)
def get_result_proxy(self):
@@ -497,13 +506,13 @@ class PGIdentifierPreparer_psycopg2(PGIdentifierPreparer):
class PGDialect_psycopg2(PGDialect):
- driver = 'psycopg2'
+ driver = "psycopg2"
if util.py2k:
supports_unicode_statements = False
supports_server_side_cursors = True
- default_paramstyle = 'pyformat'
+ default_paramstyle = "pyformat"
# set to true based on psycopg2 version
supports_sane_multi_rowcount = False
execution_ctx_cls = PGExecutionContext_psycopg2
@@ -516,16 +525,16 @@ class PGDialect_psycopg2(PGDialect):
native_jsonb=(2, 5, 4),
sane_multi_rowcount=(2, 0, 9),
array_oid=(2, 4, 3),
- hstore_adapter=(2, 4)
+ hstore_adapter=(2, 4),
)
_has_native_hstore = False
_has_native_json = False
_has_native_jsonb = False
- engine_config_types = PGDialect.engine_config_types.union([
- ('use_native_unicode', util.asbool),
- ])
+ engine_config_types = PGDialect.engine_config_types.union(
+ [("use_native_unicode", util.asbool)]
+ )
colspecs = util.update_copy(
PGDialect.colspecs,
@@ -537,15 +546,20 @@ class PGDialect_psycopg2(PGDialect):
JSON: _PGJSON,
sqltypes.JSON: _PGJSON,
JSONB: _PGJSONB,
- UUID: _PGUUID
- }
+ UUID: _PGUUID,
+ },
)
- def __init__(self, server_side_cursors=False, use_native_unicode=True,
- client_encoding=None,
- use_native_hstore=True, use_native_uuid=True,
- use_batch_mode=False,
- **kwargs):
+ def __init__(
+ self,
+ server_side_cursors=False,
+ use_native_unicode=True,
+ client_encoding=None,
+ use_native_hstore=True,
+ use_native_uuid=True,
+ use_batch_mode=False,
+ **kwargs
+ ):
PGDialect.__init__(self, **kwargs)
self.server_side_cursors = server_side_cursors
self.use_native_unicode = use_native_unicode
@@ -554,65 +568,70 @@ class PGDialect_psycopg2(PGDialect):
self.supports_unicode_binds = use_native_unicode
self.client_encoding = client_encoding
self.psycopg2_batch_mode = use_batch_mode
- if self.dbapi and hasattr(self.dbapi, '__version__'):
- m = re.match(r'(\d+)\.(\d+)(?:\.(\d+))?',
- self.dbapi.__version__)
+ if self.dbapi and hasattr(self.dbapi, "__version__"):
+ m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", self.dbapi.__version__)
if m:
self.psycopg2_version = tuple(
- int(x)
- for x in m.group(1, 2, 3)
- if x is not None)
+ int(x) for x in m.group(1, 2, 3) if x is not None
+ )
def initialize(self, connection):
super(PGDialect_psycopg2, self).initialize(connection)
- self._has_native_hstore = self.use_native_hstore and \
- self._hstore_oids(connection.connection) \
- is not None
- self._has_native_json = \
- self.psycopg2_version >= self.FEATURE_VERSION_MAP['native_json']
- self._has_native_jsonb = \
- self.psycopg2_version >= self.FEATURE_VERSION_MAP['native_jsonb']
+ self._has_native_hstore = (
+ self.use_native_hstore
+ and self._hstore_oids(connection.connection) is not None
+ )
+ self._has_native_json = (
+ self.psycopg2_version >= self.FEATURE_VERSION_MAP["native_json"]
+ )
+ self._has_native_jsonb = (
+ self.psycopg2_version >= self.FEATURE_VERSION_MAP["native_jsonb"]
+ )
# http://initd.org/psycopg/docs/news.html#what-s-new-in-psycopg-2-0-9
- self.supports_sane_multi_rowcount = \
- self.psycopg2_version >= \
- self.FEATURE_VERSION_MAP['sane_multi_rowcount'] and \
- not self.psycopg2_batch_mode
+ self.supports_sane_multi_rowcount = (
+ self.psycopg2_version
+ >= self.FEATURE_VERSION_MAP["sane_multi_rowcount"]
+ and not self.psycopg2_batch_mode
+ )
@classmethod
def dbapi(cls):
import psycopg2
+
return psycopg2
@classmethod
def _psycopg2_extensions(cls):
from psycopg2 import extensions
+
return extensions
@classmethod
def _psycopg2_extras(cls):
from psycopg2 import extras
+
return extras
@util.memoized_property
def _isolation_lookup(self):
extensions = self._psycopg2_extensions()
return {
- 'AUTOCOMMIT': extensions.ISOLATION_LEVEL_AUTOCOMMIT,
- 'READ COMMITTED': extensions.ISOLATION_LEVEL_READ_COMMITTED,
- 'READ UNCOMMITTED': extensions.ISOLATION_LEVEL_READ_UNCOMMITTED,
- 'REPEATABLE READ': extensions.ISOLATION_LEVEL_REPEATABLE_READ,
- 'SERIALIZABLE': extensions.ISOLATION_LEVEL_SERIALIZABLE
+ "AUTOCOMMIT": extensions.ISOLATION_LEVEL_AUTOCOMMIT,
+ "READ COMMITTED": extensions.ISOLATION_LEVEL_READ_COMMITTED,
+ "READ UNCOMMITTED": extensions.ISOLATION_LEVEL_READ_UNCOMMITTED,
+ "REPEATABLE READ": extensions.ISOLATION_LEVEL_REPEATABLE_READ,
+ "SERIALIZABLE": extensions.ISOLATION_LEVEL_SERIALIZABLE,
}
def set_isolation_level(self, connection, level):
try:
- level = self._isolation_lookup[level.replace('_', ' ')]
+ level = self._isolation_lookup[level.replace("_", " ")]
except KeyError:
raise exc.ArgumentError(
"Invalid value '%s' for isolation_level. "
- "Valid isolation levels for %s are %s" %
- (level, self.name, ", ".join(self._isolation_lookup))
+ "Valid isolation levels for %s are %s"
+ % (level, self.name, ", ".join(self._isolation_lookup))
)
connection.set_isolation_level(level)
@@ -623,54 +642,72 @@ class PGDialect_psycopg2(PGDialect):
fns = []
if self.client_encoding is not None:
+
def on_connect(conn):
conn.set_client_encoding(self.client_encoding)
+
fns.append(on_connect)
if self.isolation_level is not None:
+
def on_connect(conn):
self.set_isolation_level(conn, self.isolation_level)
+
fns.append(on_connect)
if self.dbapi and self.use_native_uuid:
+
def on_connect(conn):
extras.register_uuid(None, conn)
+
fns.append(on_connect)
if self.dbapi and self.use_native_unicode:
+
def on_connect(conn):
extensions.register_type(extensions.UNICODE, conn)
extensions.register_type(extensions.UNICODEARRAY, conn)
+
fns.append(on_connect)
if self.dbapi and self.use_native_hstore:
+
def on_connect(conn):
hstore_oids = self._hstore_oids(conn)
if hstore_oids is not None:
oid, array_oid = hstore_oids
- kw = {'oid': oid}
+ kw = {"oid": oid}
if util.py2k:
- kw['unicode'] = True
- if self.psycopg2_version >= \
- self.FEATURE_VERSION_MAP['array_oid']:
- kw['array_oid'] = array_oid
+ kw["unicode"] = True
+ if (
+ self.psycopg2_version
+ >= self.FEATURE_VERSION_MAP["array_oid"]
+ ):
+ kw["array_oid"] = array_oid
extras.register_hstore(conn, **kw)
+
fns.append(on_connect)
if self.dbapi and self._json_deserializer:
+
def on_connect(conn):
if self._has_native_json:
extras.register_default_json(
- conn, loads=self._json_deserializer)
+ conn, loads=self._json_deserializer
+ )
if self._has_native_jsonb:
extras.register_default_jsonb(
- conn, loads=self._json_deserializer)
+ conn, loads=self._json_deserializer
+ )
+
fns.append(on_connect)
if fns:
+
def on_connect(conn):
for fn in fns:
fn(conn)
+
return on_connect
else:
return None
@@ -684,7 +721,7 @@ class PGDialect_psycopg2(PGDialect):
@util.memoized_instancemethod
def _hstore_oids(self, conn):
- if self.psycopg2_version >= self.FEATURE_VERSION_MAP['hstore_adapter']:
+ if self.psycopg2_version >= self.FEATURE_VERSION_MAP["hstore_adapter"]:
extras = self._psycopg2_extras()
oids = extras.HstoreAdapter.get_oids(conn)
if oids is not None and oids[0]:
@@ -692,9 +729,9 @@ class PGDialect_psycopg2(PGDialect):
return None
def create_connect_args(self, url):
- opts = url.translate_connect_args(username='user')
- if 'port' in opts:
- opts['port'] = int(opts['port'])
+ opts = url.translate_connect_args(username="user")
+ if "port" in opts:
+ opts["port"] = int(opts["port"])
opts.update(url.query)
return ([], opts)
@@ -704,7 +741,7 @@ class PGDialect_psycopg2(PGDialect):
# present on old psycopg2 versions. Also,
# this flag doesn't actually help in a lot of disconnect
# situations, so don't rely on it.
- if getattr(connection, 'closed', False):
+ if getattr(connection, "closed", False):
return True
# checks based on strings. in the case that .closed
@@ -713,28 +750,29 @@ class PGDialect_psycopg2(PGDialect):
for msg in [
# these error messages from libpq: interfaces/libpq/fe-misc.c
# and interfaces/libpq/fe-secure.c.
- 'terminating connection',
- 'closed the connection',
- 'connection not open',
- 'could not receive data from server',
- 'could not send data to server',
+ "terminating connection",
+ "closed the connection",
+ "connection not open",
+ "could not receive data from server",
+ "could not send data to server",
            # psycopg2 client errors, psycopg2/connection.h,
# psycopg2/cursor.h
- 'connection already closed',
- 'cursor already closed',
+ "connection already closed",
+ "cursor already closed",
# not sure where this path is originally from, it may
# be obsolete. It really says "losed", not "closed".
- 'losed the connection unexpectedly',
+ "losed the connection unexpectedly",
# these can occur in newer SSL
- 'connection has been closed unexpectedly',
- 'SSL SYSCALL error: Bad file descriptor',
- 'SSL SYSCALL error: EOF detected',
- 'SSL error: decryption failed or bad record mac',
- 'SSL SYSCALL error: Operation timed out',
+ "connection has been closed unexpectedly",
+ "SSL SYSCALL error: Bad file descriptor",
+ "SSL SYSCALL error: EOF detected",
+ "SSL error: decryption failed or bad record mac",
+ "SSL SYSCALL error: Operation timed out",
]:
idx = str_e.find(msg)
if idx >= 0 and '"' not in str_e[:idx]:
return True
return False
+
dialect = PGDialect_psycopg2
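
The FEATURE_VERSION_MAP gating reformatted above is plain tuple comparison;
a sketch with a hypothetical driver version (map values copied from the two
entries visible in this hunk):

    FEATURE_VERSION_MAP = dict(native_jsonb=(2, 5, 4), hstore_adapter=(2, 4))

    driver_version = (2, 5, 1)  # hypothetical parsed psycopg2 version
    assert driver_version >= FEATURE_VERSION_MAP["hstore_adapter"]
    assert not driver_version >= FEATURE_VERSION_MAP["native_jsonb"]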
diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2cffi.py b/lib/sqlalchemy/dialects/postgresql/psycopg2cffi.py
index a1141a90e..7343bc973 100644
--- a/lib/sqlalchemy/dialects/postgresql/psycopg2cffi.py
+++ b/lib/sqlalchemy/dialects/postgresql/psycopg2cffi.py
@@ -28,7 +28,7 @@ from .psycopg2 import PGDialect_psycopg2
class PGDialect_psycopg2cffi(PGDialect_psycopg2):
- driver = 'psycopg2cffi'
+ driver = "psycopg2cffi"
supports_unicode_statements = True
# psycopg2cffi's first release is 2.5.0, but reports
@@ -40,21 +40,21 @@ class PGDialect_psycopg2cffi(PGDialect_psycopg2):
native_jsonb=(2, 7, 1),
sane_multi_rowcount=(2, 4, 4),
array_oid=(2, 4, 4),
- hstore_adapter=(2, 4, 4)
+ hstore_adapter=(2, 4, 4),
)
@classmethod
def dbapi(cls):
- return __import__('psycopg2cffi')
+ return __import__("psycopg2cffi")
@classmethod
def _psycopg2_extensions(cls):
- root = __import__('psycopg2cffi', fromlist=['extensions'])
+ root = __import__("psycopg2cffi", fromlist=["extensions"])
return root.extensions
@classmethod
def _psycopg2_extras(cls):
- root = __import__('psycopg2cffi', fromlist=['extras'])
+ root = __import__("psycopg2cffi", fromlist=["extras"])
return root.extras
diff --git a/lib/sqlalchemy/dialects/postgresql/pygresql.py b/lib/sqlalchemy/dialects/postgresql/pygresql.py
index 304afca44..c7edb8fc3 100644
--- a/lib/sqlalchemy/dialects/postgresql/pygresql.py
+++ b/lib/sqlalchemy/dialects/postgresql/pygresql.py
@@ -20,14 +20,20 @@ import re
from ... import exc, processors, util
from ...types import Numeric, JSON as Json
from ...sql.elements import Null
-from .base import PGDialect, PGCompiler, PGIdentifierPreparer, \
- _DECIMAL_TYPES, _FLOAT_TYPES, _INT_TYPES, UUID
+from .base import (
+ PGDialect,
+ PGCompiler,
+ PGIdentifierPreparer,
+ _DECIMAL_TYPES,
+ _FLOAT_TYPES,
+ _INT_TYPES,
+ UUID,
+)
from .hstore import HSTORE
from .json import JSON, JSONB
class _PGNumeric(Numeric):
-
def bind_processor(self, dialect):
return None
@@ -37,14 +43,15 @@ class _PGNumeric(Numeric):
if self.asdecimal:
if coltype in _FLOAT_TYPES:
return processors.to_decimal_processor_factory(
- decimal.Decimal,
- self._effective_decimal_return_scale)
+ decimal.Decimal, self._effective_decimal_return_scale
+ )
elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES:
# PyGreSQL returns Decimal natively for 1700 (numeric)
return None
else:
raise exc.InvalidRequestError(
- "Unknown PG numeric type: %d" % coltype)
+ "Unknown PG numeric type: %d" % coltype
+ )
else:
if coltype in _FLOAT_TYPES:
# PyGreSQL returns float natively for 701 (float8)
@@ -53,19 +60,21 @@ class _PGNumeric(Numeric):
return processors.to_float
else:
raise exc.InvalidRequestError(
- "Unknown PG numeric type: %d" % coltype)
+ "Unknown PG numeric type: %d" % coltype
+ )
class _PGHStore(HSTORE):
-
def bind_processor(self, dialect):
if not dialect.has_native_hstore:
return super(_PGHStore, self).bind_processor(dialect)
hstore = dialect.dbapi.Hstore
+
def process(value):
if isinstance(value, dict):
return hstore(value)
return value
+
return process
def result_processor(self, dialect, coltype):
@@ -74,7 +83,6 @@ class _PGHStore(HSTORE):
class _PGJSON(JSON):
-
def bind_processor(self, dialect):
if not dialect.has_native_json:
return super(_PGJSON, self).bind_processor(dialect)
@@ -84,7 +92,8 @@ class _PGJSON(JSON):
if value is self.NULL:
value = None
elif isinstance(value, Null) or (
- value is None and self.none_as_null):
+ value is None and self.none_as_null
+ ):
return None
if value is None or isinstance(value, (dict, list)):
return json(value)
@@ -98,7 +107,6 @@ class _PGJSON(JSON):
class _PGJSONB(JSONB):
-
def bind_processor(self, dialect):
if not dialect.has_native_json:
return super(_PGJSONB, self).bind_processor(dialect)
@@ -108,7 +116,8 @@ class _PGJSONB(JSONB):
if value is self.NULL:
value = None
elif isinstance(value, Null) or (
- value is None and self.none_as_null):
+ value is None and self.none_as_null
+ ):
return None
if value is None or isinstance(value, (dict, list)):
return json(value)
@@ -122,7 +131,6 @@ class _PGJSONB(JSONB):
class _PGUUID(UUID):
-
def bind_processor(self, dialect):
if not dialect.has_native_uuid:
return super(_PGUUID, self).bind_processor(dialect)
@@ -145,32 +153,35 @@ class _PGUUID(UUID):
if not dialect.has_native_uuid:
return super(_PGUUID, self).result_processor(dialect, coltype)
if not self.as_uuid:
+
def process(value):
if value is not None:
return str(value)
+
return process
class _PGCompiler(PGCompiler):
-
def visit_mod_binary(self, binary, operator, **kw):
- return self.process(binary.left, **kw) + " %% " + \
- self.process(binary.right, **kw)
+ return (
+ self.process(binary.left, **kw)
+ + " %% "
+ + self.process(binary.right, **kw)
+ )
def post_process_text(self, text):
- return text.replace('%', '%%')
+ return text.replace("%", "%%")
class _PGIdentifierPreparer(PGIdentifierPreparer):
-
def _escape_identifier(self, value):
value = value.replace(self.escape_quote, self.escape_to_quote)
- return value.replace('%', '%%')
+ return value.replace("%", "%%")
class PGDialect_pygresql(PGDialect):
- driver = 'pygresql'
+ driver = "pygresql"
statement_compiler = _PGCompiler
preparer = _PGIdentifierPreparer
@@ -178,6 +189,7 @@ class PGDialect_pygresql(PGDialect):
@classmethod
def dbapi(cls):
import pgdb
+
return pgdb
colspecs = util.update_copy(
@@ -189,14 +201,14 @@ class PGDialect_pygresql(PGDialect):
JSON: _PGJSON,
JSONB: _PGJSONB,
UUID: _PGUUID,
- }
+ },
)
def __init__(self, **kwargs):
super(PGDialect_pygresql, self).__init__(**kwargs)
try:
version = self.dbapi.version
- m = re.match(r'(\d+)\.(\d+)', version)
+ m = re.match(r"(\d+)\.(\d+)", version)
version = (int(m.group(1)), int(m.group(2)))
except (AttributeError, ValueError, TypeError):
version = (0, 0)
@@ -204,8 +216,10 @@ class PGDialect_pygresql(PGDialect):
if version < (5, 0):
has_native_hstore = has_native_json = has_native_uuid = False
if version != (0, 0):
- util.warn("PyGreSQL is only fully supported by SQLAlchemy"
- " since version 5.0.")
+ util.warn(
+ "PyGreSQL is only fully supported by SQLAlchemy"
+ " since version 5.0."
+ )
else:
self.supports_unicode_statements = True
self.supports_unicode_binds = True
@@ -215,10 +229,12 @@ class PGDialect_pygresql(PGDialect):
self.has_native_uuid = has_native_uuid
def create_connect_args(self, url):
- opts = url.translate_connect_args(username='user')
- if 'port' in opts:
- opts['host'] = '%s:%s' % (
- opts.get('host', '').rsplit(':', 1)[0], opts.pop('port'))
+ opts = url.translate_connect_args(username="user")
+ if "port" in opts:
+ opts["host"] = "%s:%s" % (
+ opts.get("host", "").rsplit(":", 1)[0],
+ opts.pop("port"),
+ )
opts.update(url.query)
return [], opts
diff --git a/lib/sqlalchemy/dialects/postgresql/pypostgresql.py b/lib/sqlalchemy/dialects/postgresql/pypostgresql.py
index b633323b4..93bf653a4 100644
--- a/lib/sqlalchemy/dialects/postgresql/pypostgresql.py
+++ b/lib/sqlalchemy/dialects/postgresql/pypostgresql.py
@@ -37,12 +37,12 @@ class PGExecutionContext_pypostgresql(PGExecutionContext):
class PGDialect_pypostgresql(PGDialect):
- driver = 'pypostgresql'
+ driver = "pypostgresql"
supports_unicode_statements = True
supports_unicode_binds = True
description_encoding = None
- default_paramstyle = 'pyformat'
+ default_paramstyle = "pyformat"
# requires trunk version to support sane rowcounts
# TODO: use dbapi version information to set this flag appropriately
@@ -54,22 +54,27 @@ class PGDialect_pypostgresql(PGDialect):
PGDialect.colspecs,
{
sqltypes.Numeric: PGNumeric,
-
# prevents PGNumeric from being used
sqltypes.Float: sqltypes.Float,
- }
+ },
)
@classmethod
def dbapi(cls):
from postgresql.driver import dbapi20
+
return dbapi20
_DBAPI_ERROR_NAMES = [
"Error",
- "InterfaceError", "DatabaseError", "DataError",
- "OperationalError", "IntegrityError", "InternalError",
- "ProgrammingError", "NotSupportedError"
+ "InterfaceError",
+ "DatabaseError",
+ "DataError",
+ "OperationalError",
+ "IntegrityError",
+ "InternalError",
+ "ProgrammingError",
+ "NotSupportedError",
]
@util.memoized_property
@@ -83,15 +88,16 @@ class PGDialect_pypostgresql(PGDialect):
)
def create_connect_args(self, url):
- opts = url.translate_connect_args(username='user')
- if 'port' in opts:
- opts['port'] = int(opts['port'])
+ opts = url.translate_connect_args(username="user")
+ if "port" in opts:
+ opts["port"] = int(opts["port"])
else:
- opts['port'] = 5432
+ opts["port"] = 5432
opts.update(url.query)
return ([], opts)
def is_disconnect(self, e, connection, cursor):
return "connection is closed" in str(e)
+
dialect = PGDialect_pypostgresql
diff --git a/lib/sqlalchemy/dialects/postgresql/ranges.py b/lib/sqlalchemy/dialects/postgresql/ranges.py
index eb2d86bbd..62d1275a6 100644
--- a/lib/sqlalchemy/dialects/postgresql/ranges.py
+++ b/lib/sqlalchemy/dialects/postgresql/ranges.py
@@ -7,7 +7,7 @@
from .base import ischema_names
from ... import types as sqltypes
-__all__ = ('INT4RANGE', 'INT8RANGE', 'NUMRANGE')
+__all__ = ("INT4RANGE", "INT8RANGE", "NUMRANGE")
class RangeOperators(object):
@@ -34,35 +34,36 @@ class RangeOperators(object):
def __ne__(self, other):
"Boolean expression. Returns true if two ranges are not equal"
if other is None:
- return super(
- RangeOperators.comparator_factory, self).__ne__(other)
+ return super(RangeOperators.comparator_factory, self).__ne__(
+ other
+ )
else:
- return self.expr.op('<>')(other)
+ return self.expr.op("<>")(other)
def contains(self, other, **kw):
"""Boolean expression. Returns true if the right hand operand,
which can be an element or a range, is contained within the
column.
"""
- return self.expr.op('@>')(other)
+ return self.expr.op("@>")(other)
def contained_by(self, other):
"""Boolean expression. Returns true if the column is contained
within the right hand operand.
"""
- return self.expr.op('<@')(other)
+ return self.expr.op("<@")(other)
def overlaps(self, other):
"""Boolean expression. Returns true if the column overlaps
(has points in common with) the right hand operand.
"""
- return self.expr.op('&&')(other)
+ return self.expr.op("&&")(other)
def strictly_left_of(self, other):
"""Boolean expression. Returns true if the column is strictly
left of the right hand operand.
"""
- return self.expr.op('<<')(other)
+ return self.expr.op("<<")(other)
__lshift__ = strictly_left_of
@@ -70,7 +71,7 @@ class RangeOperators(object):
"""Boolean expression. Returns true if the column is strictly
right of the right hand operand.
"""
- return self.expr.op('>>')(other)
+ return self.expr.op(">>")(other)
__rshift__ = strictly_right_of
@@ -78,26 +79,26 @@ class RangeOperators(object):
"""Boolean expression. Returns true if the range in the column
does not extend right of the range in the operand.
"""
- return self.expr.op('&<')(other)
+ return self.expr.op("&<")(other)
def not_extend_left_of(self, other):
"""Boolean expression. Returns true if the range in the column
does not extend left of the range in the operand.
"""
- return self.expr.op('&>')(other)
+ return self.expr.op("&>")(other)
def adjacent_to(self, other):
"""Boolean expression. Returns true if the range in the column
is adjacent to the range in the operand.
"""
- return self.expr.op('-|-')(other)
+ return self.expr.op("-|-")(other)
def __add__(self, other):
"""Range expression. Returns the union of the two ranges.
Will raise an exception if the resulting range is not
contiguous.
"""
- return self.expr.op('+')(other)
+ return self.expr.op("+")(other)
class INT4RANGE(RangeOperators, sqltypes.TypeEngine):
@@ -107,9 +108,10 @@ class INT4RANGE(RangeOperators, sqltypes.TypeEngine):
"""
- __visit_name__ = 'INT4RANGE'
+ __visit_name__ = "INT4RANGE"
-ischema_names['int4range'] = INT4RANGE
+
+ischema_names["int4range"] = INT4RANGE
class INT8RANGE(RangeOperators, sqltypes.TypeEngine):
@@ -119,9 +121,10 @@ class INT8RANGE(RangeOperators, sqltypes.TypeEngine):
"""
- __visit_name__ = 'INT8RANGE'
+ __visit_name__ = "INT8RANGE"
+
-ischema_names['int8range'] = INT8RANGE
+ischema_names["int8range"] = INT8RANGE
class NUMRANGE(RangeOperators, sqltypes.TypeEngine):
@@ -131,9 +134,10 @@ class NUMRANGE(RangeOperators, sqltypes.TypeEngine):
"""
- __visit_name__ = 'NUMRANGE'
+ __visit_name__ = "NUMRANGE"
+
-ischema_names['numrange'] = NUMRANGE
+ischema_names["numrange"] = NUMRANGE
class DATERANGE(RangeOperators, sqltypes.TypeEngine):
@@ -143,9 +147,10 @@ class DATERANGE(RangeOperators, sqltypes.TypeEngine):
"""
- __visit_name__ = 'DATERANGE'
+ __visit_name__ = "DATERANGE"
-ischema_names['daterange'] = DATERANGE
+
+ischema_names["daterange"] = DATERANGE
class TSRANGE(RangeOperators, sqltypes.TypeEngine):
@@ -155,9 +160,10 @@ class TSRANGE(RangeOperators, sqltypes.TypeEngine):
"""
- __visit_name__ = 'TSRANGE'
+ __visit_name__ = "TSRANGE"
+
-ischema_names['tsrange'] = TSRANGE
+ischema_names["tsrange"] = TSRANGE
class TSTZRANGE(RangeOperators, sqltypes.TypeEngine):
@@ -167,6 +173,7 @@ class TSTZRANGE(RangeOperators, sqltypes.TypeEngine):
"""
- __visit_name__ = 'TSTZRANGE'
+ __visit_name__ = "TSTZRANGE"
+
-ischema_names['tstzrange'] = TSTZRANGE
+ischema_names["tstzrange"] = TSTZRANGE
diff --git a/lib/sqlalchemy/dialects/postgresql/zxjdbc.py b/lib/sqlalchemy/dialects/postgresql/zxjdbc.py
index ef6e8f1f9..4d984443a 100644
--- a/lib/sqlalchemy/dialects/postgresql/zxjdbc.py
+++ b/lib/sqlalchemy/dialects/postgresql/zxjdbc.py
@@ -19,7 +19,6 @@ from .base import PGDialect, PGExecutionContext
class PGExecutionContext_zxjdbc(PGExecutionContext):
-
def create_cursor(self):
cursor = self._dbapi_connection.cursor()
cursor.datahandler = self.dialect.DataHandler(cursor.datahandler)
@@ -27,8 +26,8 @@ class PGExecutionContext_zxjdbc(PGExecutionContext):
class PGDialect_zxjdbc(ZxJDBCConnector, PGDialect):
- jdbc_db_name = 'postgresql'
- jdbc_driver_name = 'org.postgresql.Driver'
+ jdbc_db_name = "postgresql"
+ jdbc_driver_name = "org.postgresql.Driver"
execution_ctx_cls = PGExecutionContext_zxjdbc
@@ -37,10 +36,12 @@ class PGDialect_zxjdbc(ZxJDBCConnector, PGDialect):
def __init__(self, *args, **kwargs):
super(PGDialect_zxjdbc, self).__init__(*args, **kwargs)
from com.ziclix.python.sql.handler import PostgresqlDataHandler
+
self.DataHandler = PostgresqlDataHandler
def _get_server_version_info(self, connection):
- parts = connection.connection.dbversion.split('.')
+ parts = connection.connection.dbversion.split(".")
return tuple(int(x) for x in parts)
+
dialect = PGDialect_zxjdbc
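
The version handling above is a plain tuple conversion; for an illustrative dbversion string:

    parts = "9.6.10".split(".")
    assert tuple(int(x) for x in parts) == (9, 6, 10)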
diff --git a/lib/sqlalchemy/dialects/sqlite/__init__.py b/lib/sqlalchemy/dialects/sqlite/__init__.py
index a73581521..41f017597 100644
--- a/lib/sqlalchemy/dialects/sqlite/__init__.py
+++ b/lib/sqlalchemy/dialects/sqlite/__init__.py
@@ -8,14 +8,44 @@
from . import base, pysqlite, pysqlcipher # noqa
from sqlalchemy.dialects.sqlite.base import (
- BLOB, BOOLEAN, CHAR, DATE, DATETIME, DECIMAL, FLOAT, INTEGER, JSON, REAL,
- NUMERIC, SMALLINT, TEXT, TIME, TIMESTAMP, VARCHAR
+ BLOB,
+ BOOLEAN,
+ CHAR,
+ DATE,
+ DATETIME,
+ DECIMAL,
+ FLOAT,
+ INTEGER,
+ JSON,
+ REAL,
+ NUMERIC,
+ SMALLINT,
+ TEXT,
+ TIME,
+ TIMESTAMP,
+ VARCHAR,
)
# default dialect
base.dialect = dialect = pysqlite.dialect
-__all__ = ('BLOB', 'BOOLEAN', 'CHAR', 'DATE', 'DATETIME', 'DECIMAL',
- 'FLOAT', 'INTEGER', 'JSON', 'NUMERIC', 'SMALLINT', 'TEXT', 'TIME',
- 'TIMESTAMP', 'VARCHAR', 'REAL', 'dialect')
+__all__ = (
+ "BLOB",
+ "BOOLEAN",
+ "CHAR",
+ "DATE",
+ "DATETIME",
+ "DECIMAL",
+ "FLOAT",
+ "INTEGER",
+ "JSON",
+ "NUMERIC",
+ "SMALLINT",
+ "TEXT",
+ "TIME",
+ "TIMESTAMP",
+ "VARCHAR",
+ "REAL",
+ "dialect",
+)
diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py
index c487af898..cb9389af1 100644
--- a/lib/sqlalchemy/dialects/sqlite/base.py
+++ b/lib/sqlalchemy/dialects/sqlite/base.py
@@ -579,9 +579,20 @@ from ... import util
from ...engine import default, reflection
from ...sql import compiler
-from ...types import (BLOB, BOOLEAN, CHAR, DECIMAL, FLOAT,
- INTEGER, REAL, NUMERIC, SMALLINT, TEXT,
- TIMESTAMP, VARCHAR)
+from ...types import (
+ BLOB,
+ BOOLEAN,
+ CHAR,
+ DECIMAL,
+ FLOAT,
+ INTEGER,
+ REAL,
+ NUMERIC,
+ SMALLINT,
+ TEXT,
+ TIMESTAMP,
+ VARCHAR,
+)
from .json import JSON, JSONIndexType, JSONPathType
@@ -610,10 +621,15 @@ class _DateTimeMixin(object):
"""
spec = self._storage_format % {
- "year": 0, "month": 0, "day": 0, "hour": 0,
- "minute": 0, "second": 0, "microsecond": 0
+ "year": 0,
+ "month": 0,
+ "day": 0,
+ "hour": 0,
+ "minute": 0,
+ "second": 0,
+ "microsecond": 0,
}
- return bool(re.search(r'[^0-9]', spec))
+ return bool(re.search(r"[^0-9]", spec))
def adapt(self, cls, **kw):
if issubclass(cls, _DateTimeMixin):
@@ -628,6 +644,7 @@ class _DateTimeMixin(object):
def process(value):
return "'%s'" % bp(value)
+
return process
@@ -671,13 +688,17 @@ class DATETIME(_DateTimeMixin, sqltypes.DateTime):
)
def __init__(self, *args, **kwargs):
- truncate_microseconds = kwargs.pop('truncate_microseconds', False)
+ truncate_microseconds = kwargs.pop("truncate_microseconds", False)
super(DATETIME, self).__init__(*args, **kwargs)
if truncate_microseconds:
- assert 'storage_format' not in kwargs, "You can specify only "\
+ assert "storage_format" not in kwargs, (
+ "You can specify only "
"one of truncate_microseconds or storage_format."
- assert 'regexp' not in kwargs, "You can specify only one of "\
+ )
+ assert "regexp" not in kwargs, (
+ "You can specify only one of "
"truncate_microseconds or regexp."
+ )
self._storage_format = (
"%(year)04d-%(month)02d-%(day)02d "
"%(hour)02d:%(minute)02d:%(second)02d"
@@ -693,33 +714,37 @@ class DATETIME(_DateTimeMixin, sqltypes.DateTime):
return None
elif isinstance(value, datetime_datetime):
return format % {
- 'year': value.year,
- 'month': value.month,
- 'day': value.day,
- 'hour': value.hour,
- 'minute': value.minute,
- 'second': value.second,
- 'microsecond': value.microsecond,
+ "year": value.year,
+ "month": value.month,
+ "day": value.day,
+ "hour": value.hour,
+ "minute": value.minute,
+ "second": value.second,
+ "microsecond": value.microsecond,
}
elif isinstance(value, datetime_date):
return format % {
- 'year': value.year,
- 'month': value.month,
- 'day': value.day,
- 'hour': 0,
- 'minute': 0,
- 'second': 0,
- 'microsecond': 0,
+ "year": value.year,
+ "month": value.month,
+ "day": value.day,
+ "hour": 0,
+ "minute": 0,
+ "second": 0,
+ "microsecond": 0,
}
else:
- raise TypeError("SQLite DateTime type only accepts Python "
- "datetime and date objects as input.")
+ raise TypeError(
+ "SQLite DateTime type only accepts Python "
+ "datetime and date objects as input."
+ )
+
return process
def result_processor(self, dialect, coltype):
if self._reg:
return processors.str_to_datetime_processor_factory(
- self._reg, datetime.datetime)
+ self._reg, datetime.datetime
+ )
else:
return processors.str_to_datetime
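
The storage_format and regexp hooks consumed above follow the documented pattern for this type; a sketch with a custom slash-delimited format:

    from sqlalchemy.dialects.sqlite import DATETIME

    dt = DATETIME(
        storage_format="%(year)04d/%(month)02d/%(day)02d "
        "%(hour)02d:%(minute)02d:%(second)02d",
        regexp=r"(\d+)/(\d+)/(\d+) (\d+):(\d+):(\d+)",
    )
    # binds datetime.datetime(2004, 5, 21, 0, 0, 0) as "2004/05/21 00:00:00"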
@@ -768,19 +793,23 @@ class DATE(_DateTimeMixin, sqltypes.Date):
return None
elif isinstance(value, datetime_date):
return format % {
- 'year': value.year,
- 'month': value.month,
- 'day': value.day,
+ "year": value.year,
+ "month": value.month,
+ "day": value.day,
}
else:
- raise TypeError("SQLite Date type only accepts Python "
- "date objects as input.")
+ raise TypeError(
+ "SQLite Date type only accepts Python "
+ "date objects as input."
+ )
+
return process
def result_processor(self, dialect, coltype):
if self._reg:
return processors.str_to_datetime_processor_factory(
- self._reg, datetime.date)
+ self._reg, datetime.date
+ )
else:
return processors.str_to_date
@@ -820,13 +849,17 @@ class TIME(_DateTimeMixin, sqltypes.Time):
_storage_format = "%(hour)02d:%(minute)02d:%(second)02d.%(microsecond)06d"
def __init__(self, *args, **kwargs):
- truncate_microseconds = kwargs.pop('truncate_microseconds', False)
+ truncate_microseconds = kwargs.pop("truncate_microseconds", False)
super(TIME, self).__init__(*args, **kwargs)
if truncate_microseconds:
- assert 'storage_format' not in kwargs, "You can specify only "\
+ assert "storage_format" not in kwargs, (
+ "You can specify only "
"one of truncate_microseconds or storage_format."
- assert 'regexp' not in kwargs, "You can specify only one of "\
+ )
+ assert "regexp" not in kwargs, (
+ "You can specify only one of "
"truncate_microseconds or regexp."
+ )
self._storage_format = "%(hour)02d:%(minute)02d:%(second)02d"
def bind_processor(self, dialect):
@@ -838,23 +871,28 @@ class TIME(_DateTimeMixin, sqltypes.Time):
return None
elif isinstance(value, datetime_time):
return format % {
- 'hour': value.hour,
- 'minute': value.minute,
- 'second': value.second,
- 'microsecond': value.microsecond,
+ "hour": value.hour,
+ "minute": value.minute,
+ "second": value.second,
+ "microsecond": value.microsecond,
}
else:
- raise TypeError("SQLite Time type only accepts Python "
- "time objects as input.")
+ raise TypeError(
+ "SQLite Time type only accepts Python "
+ "time objects as input."
+ )
+
return process
def result_processor(self, dialect, coltype):
if self._reg:
return processors.str_to_datetime_processor_factory(
- self._reg, datetime.time)
+ self._reg, datetime.time
+ )
else:
return processors.str_to_time
+
colspecs = {
sqltypes.Date: DATE,
sqltypes.DateTime: DATETIME,
@@ -865,31 +903,31 @@ colspecs = {
}
ischema_names = {
- 'BIGINT': sqltypes.BIGINT,
- 'BLOB': sqltypes.BLOB,
- 'BOOL': sqltypes.BOOLEAN,
- 'BOOLEAN': sqltypes.BOOLEAN,
- 'CHAR': sqltypes.CHAR,
- 'DATE': sqltypes.DATE,
- 'DATE_CHAR': sqltypes.DATE,
- 'DATETIME': sqltypes.DATETIME,
- 'DATETIME_CHAR': sqltypes.DATETIME,
- 'DOUBLE': sqltypes.FLOAT,
- 'DECIMAL': sqltypes.DECIMAL,
- 'FLOAT': sqltypes.FLOAT,
- 'INT': sqltypes.INTEGER,
- 'INTEGER': sqltypes.INTEGER,
- 'JSON': JSON,
- 'NUMERIC': sqltypes.NUMERIC,
- 'REAL': sqltypes.REAL,
- 'SMALLINT': sqltypes.SMALLINT,
- 'TEXT': sqltypes.TEXT,
- 'TIME': sqltypes.TIME,
- 'TIME_CHAR': sqltypes.TIME,
- 'TIMESTAMP': sqltypes.TIMESTAMP,
- 'VARCHAR': sqltypes.VARCHAR,
- 'NVARCHAR': sqltypes.NVARCHAR,
- 'NCHAR': sqltypes.NCHAR,
+ "BIGINT": sqltypes.BIGINT,
+ "BLOB": sqltypes.BLOB,
+ "BOOL": sqltypes.BOOLEAN,
+ "BOOLEAN": sqltypes.BOOLEAN,
+ "CHAR": sqltypes.CHAR,
+ "DATE": sqltypes.DATE,
+ "DATE_CHAR": sqltypes.DATE,
+ "DATETIME": sqltypes.DATETIME,
+ "DATETIME_CHAR": sqltypes.DATETIME,
+ "DOUBLE": sqltypes.FLOAT,
+ "DECIMAL": sqltypes.DECIMAL,
+ "FLOAT": sqltypes.FLOAT,
+ "INT": sqltypes.INTEGER,
+ "INTEGER": sqltypes.INTEGER,
+ "JSON": JSON,
+ "NUMERIC": sqltypes.NUMERIC,
+ "REAL": sqltypes.REAL,
+ "SMALLINT": sqltypes.SMALLINT,
+ "TEXT": sqltypes.TEXT,
+ "TIME": sqltypes.TIME,
+ "TIME_CHAR": sqltypes.TIME,
+ "TIMESTAMP": sqltypes.TIMESTAMP,
+ "VARCHAR": sqltypes.VARCHAR,
+ "NVARCHAR": sqltypes.NVARCHAR,
+ "NCHAR": sqltypes.NCHAR,
}
@@ -897,17 +935,18 @@ class SQLiteCompiler(compiler.SQLCompiler):
extract_map = util.update_copy(
compiler.SQLCompiler.extract_map,
{
- 'month': '%m',
- 'day': '%d',
- 'year': '%Y',
- 'second': '%S',
- 'hour': '%H',
- 'doy': '%j',
- 'minute': '%M',
- 'epoch': '%s',
- 'dow': '%w',
- 'week': '%W',
- })
+ "month": "%m",
+ "day": "%d",
+ "year": "%Y",
+ "second": "%S",
+ "hour": "%H",
+ "doy": "%j",
+ "minute": "%M",
+ "epoch": "%s",
+ "dow": "%w",
+ "week": "%W",
+ },
+ )
def visit_now_func(self, fn, **kw):
return "CURRENT_TIMESTAMP"
@@ -916,10 +955,10 @@ class SQLiteCompiler(compiler.SQLCompiler):
return 'DATETIME(CURRENT_TIMESTAMP, "localtime")'
def visit_true(self, expr, **kw):
- return '1'
+ return "1"
def visit_false(self, expr, **kw):
- return '0'
+ return "0"
def visit_char_length_func(self, fn, **kw):
return "length%s" % self.function_argspec(fn)
@@ -934,11 +973,12 @@ class SQLiteCompiler(compiler.SQLCompiler):
try:
return "CAST(STRFTIME('%s', %s) AS INTEGER)" % (
self.extract_map[extract.field],
- self.process(extract.expr, **kw)
+ self.process(extract.expr, **kw),
)
except KeyError:
raise exc.CompileError(
- "%s is not a valid extract argument." % extract.field)
+ "%s is not a valid extract argument." % extract.field
+ )
def limit_clause(self, select, **kw):
text = ""
@@ -954,35 +994,41 @@ class SQLiteCompiler(compiler.SQLCompiler):
def for_update_clause(self, select, **kw):
# sqlite has no "FOR UPDATE" AFAICT
- return ''
+ return ""
def visit_is_distinct_from_binary(self, binary, operator, **kw):
- return "%s IS NOT %s" % (self.process(binary.left),
- self.process(binary.right))
+ return "%s IS NOT %s" % (
+ self.process(binary.left),
+ self.process(binary.right),
+ )
def visit_isnot_distinct_from_binary(self, binary, operator, **kw):
- return "%s IS %s" % (self.process(binary.left),
- self.process(binary.right))
+ return "%s IS %s" % (
+ self.process(binary.left),
+ self.process(binary.right),
+ )
def visit_json_getitem_op_binary(self, binary, operator, **kw):
return "JSON_QUOTE(JSON_EXTRACT(%s, %s))" % (
self.process(binary.left, **kw),
- self.process(binary.right, **kw))
+ self.process(binary.right, **kw),
+ )
def visit_json_path_getitem_op_binary(self, binary, operator, **kw):
return "JSON_QUOTE(JSON_EXTRACT(%s, %s))" % (
self.process(binary.left, **kw),
- self.process(binary.right, **kw))
+ self.process(binary.right, **kw),
+ )
def visit_empty_set_expr(self, type_):
- return 'SELECT 1 FROM (SELECT 1) WHERE 1!=1'
+ return "SELECT 1 FROM (SELECT 1) WHERE 1!=1"
class SQLiteDDLCompiler(compiler.DDLCompiler):
-
def get_column_specification(self, column, **kwargs):
coltype = self.dialect.type_compiler.process(
- column.type, type_expression=column)
+ column.type, type_expression=column
+ )
colspec = self.preparer.format_column(column) + " " + coltype
default = self.get_column_default_string(column)
if default is not None:
@@ -991,29 +1037,33 @@ class SQLiteDDLCompiler(compiler.DDLCompiler):
if not column.nullable:
colspec += " NOT NULL"
- on_conflict_clause = column.dialect_options['sqlite'][
- 'on_conflict_not_null']
+ on_conflict_clause = column.dialect_options["sqlite"][
+ "on_conflict_not_null"
+ ]
if on_conflict_clause is not None:
colspec += " ON CONFLICT " + on_conflict_clause
if column.primary_key:
if (
- column.autoincrement is True and
- len(column.table.primary_key.columns) != 1
+ column.autoincrement is True
+ and len(column.table.primary_key.columns) != 1
):
raise exc.CompileError(
"SQLite does not support autoincrement for "
- "composite primary keys")
+ "composite primary keys"
+ )
- if (column.table.dialect_options['sqlite']['autoincrement'] and
- len(column.table.primary_key.columns) == 1 and
- issubclass(
- column.type._type_affinity, sqltypes.Integer) and
- not column.foreign_keys):
+ if (
+ column.table.dialect_options["sqlite"]["autoincrement"]
+ and len(column.table.primary_key.columns) == 1
+ and issubclass(column.type._type_affinity, sqltypes.Integer)
+ and not column.foreign_keys
+ ):
colspec += " PRIMARY KEY"
- on_conflict_clause = column.dialect_options['sqlite'][
- 'on_conflict_primary_key']
+ on_conflict_clause = column.dialect_options["sqlite"][
+ "on_conflict_primary_key"
+ ]
if on_conflict_clause is not None:
colspec += " ON CONFLICT " + on_conflict_clause
@@ -1027,21 +1077,25 @@ class SQLiteDDLCompiler(compiler.DDLCompiler):
# with the column itself.
if len(constraint.columns) == 1:
c = list(constraint)[0]
- if (c.primary_key and
- c.table.dialect_options['sqlite']['autoincrement'] and
- issubclass(c.type._type_affinity, sqltypes.Integer) and
- not c.foreign_keys):
+ if (
+ c.primary_key
+ and c.table.dialect_options["sqlite"]["autoincrement"]
+ and issubclass(c.type._type_affinity, sqltypes.Integer)
+ and not c.foreign_keys
+ ):
return None
- text = super(
- SQLiteDDLCompiler,
- self).visit_primary_key_constraint(constraint)
+ text = super(SQLiteDDLCompiler, self).visit_primary_key_constraint(
+ constraint
+ )
- on_conflict_clause = constraint.dialect_options['sqlite'][
- 'on_conflict']
+ on_conflict_clause = constraint.dialect_options["sqlite"][
+ "on_conflict"
+ ]
if on_conflict_clause is None and len(constraint.columns) == 1:
- on_conflict_clause = list(constraint)[0].\
- dialect_options['sqlite']['on_conflict_primary_key']
+ on_conflict_clause = list(constraint)[0].dialect_options["sqlite"][
+ "on_conflict_primary_key"
+ ]
if on_conflict_clause is not None:
text += " ON CONFLICT " + on_conflict_clause
@@ -1049,15 +1103,17 @@ class SQLiteDDLCompiler(compiler.DDLCompiler):
return text
def visit_unique_constraint(self, constraint):
- text = super(
- SQLiteDDLCompiler,
- self).visit_unique_constraint(constraint)
+ text = super(SQLiteDDLCompiler, self).visit_unique_constraint(
+ constraint
+ )
- on_conflict_clause = constraint.dialect_options['sqlite'][
- 'on_conflict']
+ on_conflict_clause = constraint.dialect_options["sqlite"][
+ "on_conflict"
+ ]
if on_conflict_clause is None and len(constraint.columns) == 1:
- on_conflict_clause = list(constraint)[0].\
- dialect_options['sqlite']['on_conflict_unique']
+ on_conflict_clause = list(constraint)[0].dialect_options["sqlite"][
+ "on_conflict_unique"
+ ]
if on_conflict_clause is not None:
text += " ON CONFLICT " + on_conflict_clause
@@ -1065,12 +1121,13 @@ class SQLiteDDLCompiler(compiler.DDLCompiler):
return text
def visit_check_constraint(self, constraint):
- text = super(
- SQLiteDDLCompiler,
- self).visit_check_constraint(constraint)
+ text = super(SQLiteDDLCompiler, self).visit_check_constraint(
+ constraint
+ )
- on_conflict_clause = constraint.dialect_options['sqlite'][
- 'on_conflict']
+ on_conflict_clause = constraint.dialect_options["sqlite"][
+ "on_conflict"
+ ]
if on_conflict_clause is not None:
text += " ON CONFLICT " + on_conflict_clause
@@ -1078,14 +1135,15 @@ class SQLiteDDLCompiler(compiler.DDLCompiler):
return text
def visit_column_check_constraint(self, constraint):
- text = super(
- SQLiteDDLCompiler,
- self).visit_column_check_constraint(constraint)
+ text = super(SQLiteDDLCompiler, self).visit_column_check_constraint(
+ constraint
+ )
- if constraint.dialect_options['sqlite']['on_conflict'] is not None:
+ if constraint.dialect_options["sqlite"]["on_conflict"] is not None:
raise exc.CompileError(
"SQLite does not support on conflict clause for "
- "column check constraint")
+ "column check constraint"
+ )
return text
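
The ON CONFLICT plumbing in the visit_* methods above is driven by dialect keyword arguments on columns and constraints; a hedged sketch of that spelling (the table and column names are placeholders):

    from sqlalchemy import Column, Integer, MetaData, String, Table, UniqueConstraint

    note = Table(
        "note",
        MetaData(),
        Column("id", Integer, primary_key=True),
        Column(
            "body",
            String(200),
            nullable=False,
            sqlite_on_conflict_not_null="IGNORE",
        ),
        UniqueConstraint("body", sqlite_on_conflict="REPLACE"),
    )
    # CREATE TABLE renders, among other clauses:
    #   body VARCHAR(200) NOT NULL ON CONFLICT IGNORE
    #   UNIQUE (body) ON CONFLICT REPLACE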
@@ -1097,40 +1155,40 @@ class SQLiteDDLCompiler(compiler.DDLCompiler):
if local_table.schema != remote_table.schema:
return None
else:
- return super(
- SQLiteDDLCompiler,
- self).visit_foreign_key_constraint(constraint)
+ return super(SQLiteDDLCompiler, self).visit_foreign_key_constraint(
+ constraint
+ )
def define_constraint_remote_table(self, constraint, table, preparer):
"""Format the remote table clause of a CREATE CONSTRAINT clause."""
return preparer.format_table(table, use_schema=False)
- def visit_create_index(self, create, include_schema=False,
- include_table_schema=True):
+ def visit_create_index(
+ self, create, include_schema=False, include_table_schema=True
+ ):
index = create.element
self._verify_index_table(index)
preparer = self.preparer
text = "CREATE "
if index.unique:
text += "UNIQUE "
- text += "INDEX %s ON %s (%s)" \
- % (
- self._prepared_index_name(index,
- include_schema=True),
- preparer.format_table(index.table,
- use_schema=False),
- ', '.join(
- self.sql_compiler.process(
- expr, include_table=False, literal_binds=True) for
- expr in index.expressions)
- )
+ text += "INDEX %s ON %s (%s)" % (
+ self._prepared_index_name(index, include_schema=True),
+ preparer.format_table(index.table, use_schema=False),
+ ", ".join(
+ self.sql_compiler.process(
+ expr, include_table=False, literal_binds=True
+ )
+ for expr in index.expressions
+ ),
+ )
whereclause = index.dialect_options["sqlite"]["where"]
if whereclause is not None:
where_compiled = self.sql_compiler.process(
- whereclause, include_table=False,
- literal_binds=True)
+ whereclause, include_table=False, literal_binds=True
+ )
text += " WHERE " + where_compiled
return text
@@ -1141,22 +1199,28 @@ class SQLiteTypeCompiler(compiler.GenericTypeCompiler):
return self.visit_BLOB(type_)
def visit_DATETIME(self, type_, **kw):
- if not isinstance(type_, _DateTimeMixin) or \
- type_.format_is_text_affinity:
+ if (
+ not isinstance(type_, _DateTimeMixin)
+ or type_.format_is_text_affinity
+ ):
return super(SQLiteTypeCompiler, self).visit_DATETIME(type_)
else:
return "DATETIME_CHAR"
def visit_DATE(self, type_, **kw):
- if not isinstance(type_, _DateTimeMixin) or \
- type_.format_is_text_affinity:
+ if (
+ not isinstance(type_, _DateTimeMixin)
+ or type_.format_is_text_affinity
+ ):
return super(SQLiteTypeCompiler, self).visit_DATE(type_)
else:
return "DATE_CHAR"
def visit_TIME(self, type_, **kw):
- if not isinstance(type_, _DateTimeMixin) or \
- type_.format_is_text_affinity:
+ if (
+ not isinstance(type_, _DateTimeMixin)
+ or type_.format_is_text_affinity
+ ):
return super(SQLiteTypeCompiler, self).visit_TIME(type_)
else:
return "TIME_CHAR"
@@ -1169,33 +1233,135 @@ class SQLiteTypeCompiler(compiler.GenericTypeCompiler):
class SQLiteIdentifierPreparer(compiler.IdentifierPreparer):
- reserved_words = set([
- 'add', 'after', 'all', 'alter', 'analyze', 'and', 'as', 'asc',
- 'attach', 'autoincrement', 'before', 'begin', 'between', 'by',
- 'cascade', 'case', 'cast', 'check', 'collate', 'column', 'commit',
- 'conflict', 'constraint', 'create', 'cross', 'current_date',
- 'current_time', 'current_timestamp', 'database', 'default',
- 'deferrable', 'deferred', 'delete', 'desc', 'detach', 'distinct',
- 'drop', 'each', 'else', 'end', 'escape', 'except', 'exclusive',
- 'explain', 'false', 'fail', 'for', 'foreign', 'from', 'full', 'glob',
- 'group', 'having', 'if', 'ignore', 'immediate', 'in', 'index',
- 'indexed', 'initially', 'inner', 'insert', 'instead', 'intersect',
- 'into', 'is', 'isnull', 'join', 'key', 'left', 'like', 'limit',
- 'match', 'natural', 'not', 'notnull', 'null', 'of', 'offset', 'on',
- 'or', 'order', 'outer', 'plan', 'pragma', 'primary', 'query',
- 'raise', 'references', 'reindex', 'rename', 'replace', 'restrict',
- 'right', 'rollback', 'row', 'select', 'set', 'table', 'temp',
- 'temporary', 'then', 'to', 'transaction', 'trigger', 'true', 'union',
- 'unique', 'update', 'using', 'vacuum', 'values', 'view', 'virtual',
- 'when', 'where',
- ])
+ reserved_words = set(
+ [
+ "add",
+ "after",
+ "all",
+ "alter",
+ "analyze",
+ "and",
+ "as",
+ "asc",
+ "attach",
+ "autoincrement",
+ "before",
+ "begin",
+ "between",
+ "by",
+ "cascade",
+ "case",
+ "cast",
+ "check",
+ "collate",
+ "column",
+ "commit",
+ "conflict",
+ "constraint",
+ "create",
+ "cross",
+ "current_date",
+ "current_time",
+ "current_timestamp",
+ "database",
+ "default",
+ "deferrable",
+ "deferred",
+ "delete",
+ "desc",
+ "detach",
+ "distinct",
+ "drop",
+ "each",
+ "else",
+ "end",
+ "escape",
+ "except",
+ "exclusive",
+ "explain",
+ "false",
+ "fail",
+ "for",
+ "foreign",
+ "from",
+ "full",
+ "glob",
+ "group",
+ "having",
+ "if",
+ "ignore",
+ "immediate",
+ "in",
+ "index",
+ "indexed",
+ "initially",
+ "inner",
+ "insert",
+ "instead",
+ "intersect",
+ "into",
+ "is",
+ "isnull",
+ "join",
+ "key",
+ "left",
+ "like",
+ "limit",
+ "match",
+ "natural",
+ "not",
+ "notnull",
+ "null",
+ "of",
+ "offset",
+ "on",
+ "or",
+ "order",
+ "outer",
+ "plan",
+ "pragma",
+ "primary",
+ "query",
+ "raise",
+ "references",
+ "reindex",
+ "rename",
+ "replace",
+ "restrict",
+ "right",
+ "rollback",
+ "row",
+ "select",
+ "set",
+ "table",
+ "temp",
+ "temporary",
+ "then",
+ "to",
+ "transaction",
+ "trigger",
+ "true",
+ "union",
+ "unique",
+ "update",
+ "using",
+ "vacuum",
+ "values",
+ "view",
+ "virtual",
+ "when",
+ "where",
+ ]
+ )
class SQLiteExecutionContext(default.DefaultExecutionContext):
@util.memoized_property
def _preserve_raw_colnames(self):
- return not self.dialect._broken_dotted_colnames or \
- self.execution_options.get("sqlite_raw_colnames", False)
+ return (
+ not self.dialect._broken_dotted_colnames
+ or self.execution_options.get("sqlite_raw_colnames", False)
+ )
def _translate_colname(self, colname):
# TODO: detect SQLite version 3.10.0 or greater;
@@ -1212,7 +1378,7 @@ class SQLiteExecutionContext(default.DefaultExecutionContext):
class SQLiteDialect(default.DefaultDialect):
- name = 'sqlite'
+ name = "sqlite"
supports_alter = False
supports_unicode_statements = True
supports_unicode_binds = True
@@ -1221,7 +1387,7 @@ class SQLiteDialect(default.DefaultDialect):
supports_cast = True
supports_multivalues_insert = True
- default_paramstyle = 'qmark'
+ default_paramstyle = "qmark"
execution_ctx_cls = SQLiteExecutionContext
statement_compiler = SQLiteCompiler
ddl_compiler = SQLiteDDLCompiler
@@ -1235,27 +1401,30 @@ class SQLiteDialect(default.DefaultDialect):
supports_default_values = True
construct_arguments = [
- (sa_schema.Table, {
- "autoincrement": False
- }),
- (sa_schema.Index, {
- "where": None,
- }),
- (sa_schema.Column, {
- "on_conflict_primary_key": None,
- "on_conflict_not_null": None,
- "on_conflict_unique": None,
- }),
- (sa_schema.Constraint, {
- "on_conflict": None,
- }),
+ (sa_schema.Table, {"autoincrement": False}),
+ (sa_schema.Index, {"where": None}),
+ (
+ sa_schema.Column,
+ {
+ "on_conflict_primary_key": None,
+ "on_conflict_not_null": None,
+ "on_conflict_unique": None,
+ },
+ ),
+ (sa_schema.Constraint, {"on_conflict": None}),
]
_broken_fk_pragma_quotes = False
_broken_dotted_colnames = False
- def __init__(self, isolation_level=None, native_datetime=False,
- _json_serializer=None, _json_deserializer=None, **kwargs):
+ def __init__(
+ self,
+ isolation_level=None,
+ native_datetime=False,
+ _json_serializer=None,
+ _json_deserializer=None,
+ **kwargs
+ ):
default.DefaultDialect.__init__(self, **kwargs)
self.isolation_level = isolation_level
self._json_serializer = _json_serializer
@@ -1269,35 +1438,42 @@ class SQLiteDialect(default.DefaultDialect):
if self.dbapi is not None:
self.supports_right_nested_joins = (
- self.dbapi.sqlite_version_info >= (3, 7, 16))
- self._broken_dotted_colnames = (
- self.dbapi.sqlite_version_info < (3, 10, 0)
+ self.dbapi.sqlite_version_info >= (3, 7, 16)
+ )
+ self._broken_dotted_colnames = self.dbapi.sqlite_version_info < (
+ 3,
+ 10,
+ 0,
+ )
+ self.supports_default_values = self.dbapi.sqlite_version_info >= (
+ 3,
+ 3,
+ 8,
)
- self.supports_default_values = (
- self.dbapi.sqlite_version_info >= (3, 3, 8))
- self.supports_cast = (
- self.dbapi.sqlite_version_info >= (3, 2, 3))
+ self.supports_cast = self.dbapi.sqlite_version_info >= (3, 2, 3)
self.supports_multivalues_insert = (
# http://www.sqlite.org/releaselog/3_7_11.html
- self.dbapi.sqlite_version_info >= (3, 7, 11))
+ self.dbapi.sqlite_version_info
+ >= (3, 7, 11)
+ )
# see http://www.sqlalchemy.org/trac/ticket/2568
# as well as http://www.sqlite.org/src/info/600482d161
- self._broken_fk_pragma_quotes = (
- self.dbapi.sqlite_version_info < (3, 6, 14))
+ self._broken_fk_pragma_quotes = self.dbapi.sqlite_version_info < (
+ 3,
+ 6,
+ 14,
+ )
- _isolation_lookup = {
- 'READ UNCOMMITTED': 1,
- 'SERIALIZABLE': 0,
- }
+ _isolation_lookup = {"READ UNCOMMITTED": 1, "SERIALIZABLE": 0}
def set_isolation_level(self, connection, level):
try:
- isolation_level = self._isolation_lookup[level.replace('_', ' ')]
+ isolation_level = self._isolation_lookup[level.replace("_", " ")]
except KeyError:
raise exc.ArgumentError(
"Invalid value '%s' for isolation_level. "
- "Valid isolation levels for %s are %s" %
- (level, self.name, ", ".join(self._isolation_lookup))
+ "Valid isolation levels for %s are %s"
+ % (level, self.name, ", ".join(self._isolation_lookup))
)
cursor = connection.cursor()
cursor.execute("PRAGMA read_uncommitted = %d" % isolation_level)
@@ -1305,7 +1481,7 @@ class SQLiteDialect(default.DefaultDialect):
def get_isolation_level(self, connection):
cursor = connection.cursor()
- cursor.execute('PRAGMA read_uncommitted')
+ cursor.execute("PRAGMA read_uncommitted")
res = cursor.fetchone()
if res:
value = res[0]
@@ -1327,8 +1503,10 @@ class SQLiteDialect(default.DefaultDialect):
def on_connect(self):
if self.isolation_level is not None:
+
def connect(conn):
self.set_isolation_level(conn, self.isolation_level)
+
return connect
else:
return None
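
Since the two supported levels above map directly onto PRAGMA read_uncommitted, the engine-level spelling looks like this (the database file name is a placeholder):

    from sqlalchemy import create_engine

    engine = create_engine(
        "sqlite:///example.db",
        isolation_level="READ UNCOMMITTED",  # emits PRAGMA read_uncommitted = 1
    )
    # isolation_level="SERIALIZABLE" would emit PRAGMA read_uncommitted = 0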
@@ -1344,44 +1522,51 @@ class SQLiteDialect(default.DefaultDialect):
def get_table_names(self, connection, schema=None, **kw):
if schema is not None:
qschema = self.identifier_preparer.quote_identifier(schema)
- master = '%s.sqlite_master' % qschema
+ master = "%s.sqlite_master" % qschema
else:
master = "sqlite_master"
- s = ("SELECT name FROM %s "
- "WHERE type='table' ORDER BY name") % (master,)
+ s = ("SELECT name FROM %s " "WHERE type='table' ORDER BY name") % (
+ master,
+ )
rs = connection.execute(s)
return [row[0] for row in rs]
@reflection.cache
def get_temp_table_names(self, connection, **kw):
- s = "SELECT name FROM sqlite_temp_master "\
+ s = (
+ "SELECT name FROM sqlite_temp_master "
"WHERE type='table' ORDER BY name "
+ )
rs = connection.execute(s)
return [row[0] for row in rs]
@reflection.cache
def get_temp_view_names(self, connection, **kw):
- s = "SELECT name FROM sqlite_temp_master "\
+ s = (
+ "SELECT name FROM sqlite_temp_master "
"WHERE type='view' ORDER BY name "
+ )
rs = connection.execute(s)
return [row[0] for row in rs]
def has_table(self, connection, table_name, schema=None):
info = self._get_table_pragma(
- connection, "table_info", table_name, schema=schema)
+ connection, "table_info", table_name, schema=schema
+ )
return bool(info)
@reflection.cache
def get_view_names(self, connection, schema=None, **kw):
if schema is not None:
qschema = self.identifier_preparer.quote_identifier(schema)
- master = '%s.sqlite_master' % qschema
+ master = "%s.sqlite_master" % qschema
else:
master = "sqlite_master"
- s = ("SELECT name FROM %s "
- "WHERE type='view' ORDER BY name") % (master,)
+ s = ("SELECT name FROM %s " "WHERE type='view' ORDER BY name") % (
+ master,
+ )
rs = connection.execute(s)
return [row[0] for row in rs]
@@ -1390,21 +1575,27 @@ class SQLiteDialect(default.DefaultDialect):
def get_view_definition(self, connection, view_name, schema=None, **kw):
if schema is not None:
qschema = self.identifier_preparer.quote_identifier(schema)
- master = '%s.sqlite_master' % qschema
- s = ("SELECT sql FROM %s WHERE name = '%s'"
- "AND type='view'") % (master, view_name)
+ master = "%s.sqlite_master" % qschema
+ s = ("SELECT sql FROM %s WHERE name = '%s'" "AND type='view'") % (
+ master,
+ view_name,
+ )
rs = connection.execute(s)
else:
try:
- s = ("SELECT sql FROM "
- " (SELECT * FROM sqlite_master UNION ALL "
- " SELECT * FROM sqlite_temp_master) "
- "WHERE name = '%s' "
- "AND type='view'") % view_name
+ s = (
+ "SELECT sql FROM "
+ " (SELECT * FROM sqlite_master UNION ALL "
+ " SELECT * FROM sqlite_temp_master) "
+ "WHERE name = '%s' "
+ "AND type='view'"
+ ) % view_name
rs = connection.execute(s)
except exc.DBAPIError:
- s = ("SELECT sql FROM sqlite_master WHERE name = '%s' "
- "AND type='view'") % view_name
+ s = (
+ "SELECT sql FROM sqlite_master WHERE name = '%s' "
+ "AND type='view'"
+ ) % view_name
rs = connection.execute(s)
result = rs.fetchall()
@@ -1414,15 +1605,24 @@ class SQLiteDialect(default.DefaultDialect):
@reflection.cache
def get_columns(self, connection, table_name, schema=None, **kw):
info = self._get_table_pragma(
- connection, "table_info", table_name, schema=schema)
+ connection, "table_info", table_name, schema=schema
+ )
columns = []
for row in info:
(name, type_, nullable, default, primary_key) = (
- row[1], row[2].upper(), not row[3], row[4], row[5])
+ row[1],
+ row[2].upper(),
+ not row[3],
+ row[4],
+ row[5],
+ )
- columns.append(self._get_column_info(name, type_, nullable,
- default, primary_key))
+ columns.append(
+ self._get_column_info(
+ name, type_, nullable, default, primary_key
+ )
+ )
return columns
def _get_column_info(self, name, type_, nullable, default, primary_key):
@@ -1432,12 +1632,12 @@ class SQLiteDialect(default.DefaultDialect):
default = util.text_type(default)
return {
- 'name': name,
- 'type': coltype,
- 'nullable': nullable,
- 'default': default,
- 'autoincrement': 'auto',
- 'primary_key': primary_key,
+ "name": name,
+ "type": coltype,
+ "nullable": nullable,
+ "default": default,
+ "autoincrement": "auto",
+ "primary_key": primary_key,
}
def _resolve_type_affinity(self, type_):
@@ -1457,36 +1657,37 @@ class SQLiteDialect(default.DefaultDialect):
DATE and DOUBLE).
"""
- match = re.match(r'([\w ]+)(\(.*?\))?', type_)
+ match = re.match(r"([\w ]+)(\(.*?\))?", type_)
if match:
coltype = match.group(1)
args = match.group(2)
else:
- coltype = ''
- args = ''
+ coltype = ""
+ args = ""
if coltype in self.ischema_names:
coltype = self.ischema_names[coltype]
- elif 'INT' in coltype:
+ elif "INT" in coltype:
coltype = sqltypes.INTEGER
- elif 'CHAR' in coltype or 'CLOB' in coltype or 'TEXT' in coltype:
+ elif "CHAR" in coltype or "CLOB" in coltype or "TEXT" in coltype:
coltype = sqltypes.TEXT
- elif 'BLOB' in coltype or not coltype:
+ elif "BLOB" in coltype or not coltype:
coltype = sqltypes.NullType
- elif 'REAL' in coltype or 'FLOA' in coltype or 'DOUB' in coltype:
+ elif "REAL" in coltype or "FLOA" in coltype or "DOUB" in coltype:
coltype = sqltypes.REAL
else:
coltype = sqltypes.NUMERIC
if args is not None:
- args = re.findall(r'(\d+)', args)
+ args = re.findall(r"(\d+)", args)
try:
coltype = coltype(*[int(a) for a in args])
except TypeError:
util.warn(
"Could not instantiate type %s with "
- "reflected arguments %s; using no arguments." %
- (coltype, args))
+ "reflected arguments %s; using no arguments."
+ % (coltype, args)
+ )
coltype = coltype()
else:
coltype = coltype()
@@ -1498,58 +1699,59 @@ class SQLiteDialect(default.DefaultDialect):
constraint_name = None
table_data = self._get_table_sql(connection, table_name, schema=schema)
if table_data:
- PK_PATTERN = r'CONSTRAINT (\w+) PRIMARY KEY'
+ PK_PATTERN = r"CONSTRAINT (\w+) PRIMARY KEY"
result = re.search(PK_PATTERN, table_data, re.I)
constraint_name = result.group(1) if result else None
cols = self.get_columns(connection, table_name, schema, **kw)
pkeys = []
for col in cols:
- if col['primary_key']:
- pkeys.append(col['name'])
+ if col["primary_key"]:
+ pkeys.append(col["name"])
- return {'constrained_columns': pkeys, 'name': constraint_name}
+ return {"constrained_columns": pkeys, "name": constraint_name}
@reflection.cache
def get_foreign_keys(self, connection, table_name, schema=None, **kw):
# sqlite makes this *extremely difficult*.
# First, use the pragma to get the actual FKs.
pragma_fks = self._get_table_pragma(
- connection, "foreign_key_list",
- table_name, schema=schema
+ connection, "foreign_key_list", table_name, schema=schema
)
fks = {}
for row in pragma_fks:
- (numerical_id, rtbl, lcol, rcol) = (
- row[0], row[2], row[3], row[4])
+ (numerical_id, rtbl, lcol, rcol) = (row[0], row[2], row[3], row[4])
if rcol is None:
rcol = lcol
if self._broken_fk_pragma_quotes:
- rtbl = re.sub(r'^[\"\[`\']|[\"\]`\']$', '', rtbl)
+ rtbl = re.sub(r"^[\"\[`\']|[\"\]`\']$", "", rtbl)
if numerical_id in fks:
fk = fks[numerical_id]
else:
fk = fks[numerical_id] = {
- 'name': None,
- 'constrained_columns': [],
- 'referred_schema': schema,
- 'referred_table': rtbl,
- 'referred_columns': [],
- 'options': {}
+ "name": None,
+ "constrained_columns": [],
+ "referred_schema": schema,
+ "referred_table": rtbl,
+ "referred_columns": [],
+ "options": {},
}
fks[numerical_id] = fk
- fk['constrained_columns'].append(lcol)
- fk['referred_columns'].append(rcol)
+ fk["constrained_columns"].append(lcol)
+ fk["referred_columns"].append(rcol)
def fk_sig(constrained_columns, referred_table, referred_columns):
- return tuple(constrained_columns) + (referred_table,) + \
- tuple(referred_columns)
+ return (
+ tuple(constrained_columns)
+ + (referred_table,)
+ + tuple(referred_columns)
+ )
# then, parse the actual SQL and attempt to find DDL that matches
# the names as well. SQLite saves the DDL in whatever format
@@ -1558,10 +1760,13 @@ class SQLiteDialect(default.DefaultDialect):
keys_by_signature = dict(
(
fk_sig(
- fk['constrained_columns'],
- fk['referred_table'], fk['referred_columns']),
- fk
- ) for fk in fks.values()
+ fk["constrained_columns"],
+ fk["referred_table"],
+ fk["referred_columns"],
+ ),
+ fk,
+ )
+ for fk in fks.values()
)
table_data = self._get_table_sql(connection, table_name, schema=schema)
@@ -1571,55 +1776,66 @@ class SQLiteDialect(default.DefaultDialect):
def parse_fks():
FK_PATTERN = (
- r'(?:CONSTRAINT (\w+) +)?'
- r'FOREIGN KEY *\( *(.+?) *\) +'
+ r"(?:CONSTRAINT (\w+) +)?"
+ r"FOREIGN KEY *\( *(.+?) *\) +"
r'REFERENCES +(?:(?:"(.+?)")|([a-z0-9_]+)) *\((.+?)\) *'
- r'((?:ON (?:DELETE|UPDATE) '
- r'(?:SET NULL|SET DEFAULT|CASCADE|RESTRICT|NO ACTION) *)*)'
+ r"((?:ON (?:DELETE|UPDATE) "
+ r"(?:SET NULL|SET DEFAULT|CASCADE|RESTRICT|NO ACTION) *)*)"
)
for match in re.finditer(FK_PATTERN, table_data, re.I):
(
- constraint_name, constrained_columns,
- referred_quoted_name, referred_name,
- referred_columns, onupdatedelete) = \
- match.group(1, 2, 3, 4, 5, 6)
+ constraint_name,
+ constrained_columns,
+ referred_quoted_name,
+ referred_name,
+ referred_columns,
+ onupdatedelete,
+ ) = match.group(1, 2, 3, 4, 5, 6)
constrained_columns = list(
- self._find_cols_in_sig(constrained_columns))
+ self._find_cols_in_sig(constrained_columns)
+ )
if not referred_columns:
referred_columns = constrained_columns
else:
referred_columns = list(
- self._find_cols_in_sig(referred_columns))
+ self._find_cols_in_sig(referred_columns)
+ )
referred_name = referred_quoted_name or referred_name
options = {}
for token in re.split(r" *\bON\b *", onupdatedelete.upper()):
if token.startswith("DELETE"):
- options['ondelete'] = token[6:].strip()
+ options["ondelete"] = token[6:].strip()
elif token.startswith("UPDATE"):
options["onupdate"] = token[6:].strip()
yield (
- constraint_name, constrained_columns,
- referred_name, referred_columns, options)
+ constraint_name,
+ constrained_columns,
+ referred_name,
+ referred_columns,
+ options,
+ )
+
fkeys = []
for (
- constraint_name, constrained_columns,
- referred_name, referred_columns, options) in parse_fks():
- sig = fk_sig(
- constrained_columns, referred_name, referred_columns)
+ constraint_name,
+ constrained_columns,
+ referred_name,
+ referred_columns,
+ options,
+ ) in parse_fks():
+ sig = fk_sig(constrained_columns, referred_name, referred_columns)
if sig not in keys_by_signature:
util.warn(
"WARNING: SQL-parsed foreign key constraint "
"'%s' could not be located in PRAGMA "
- "foreign_keys for table %s" % (
- sig,
- table_name
- ))
+ "foreign_keys for table %s" % (sig, table_name)
+ )
continue
key = keys_by_signature.pop(sig)
- key['name'] = constraint_name
- key['options'] = options
+ key["name"] = constraint_name
+ key["options"] = options
fkeys.append(key)
# assume the remainders are the unnamed, inline constraints, just
# use them as is as it's extremely difficult to parse inline
@@ -1632,20 +1848,26 @@ class SQLiteDialect(default.DefaultDialect):
yield match.group(1) or match.group(2)
@reflection.cache
- def get_unique_constraints(self, connection, table_name,
- schema=None, **kw):
+ def get_unique_constraints(
+ self, connection, table_name, schema=None, **kw
+ ):
auto_index_by_sig = {}
for idx in self.get_indexes(
- connection, table_name, schema=schema,
- include_auto_indexes=True, **kw):
- if not idx['name'].startswith("sqlite_autoindex"):
+ connection,
+ table_name,
+ schema=schema,
+ include_auto_indexes=True,
+ **kw
+ ):
+ if not idx["name"].startswith("sqlite_autoindex"):
continue
- sig = tuple(idx['column_names'])
+ sig = tuple(idx["column_names"])
auto_index_by_sig[sig] = idx
table_data = self._get_table_sql(
- connection, table_name, schema=schema, **kw)
+ connection, table_name, schema=schema, **kw
+ )
if not table_data:
return []
@@ -1654,8 +1876,8 @@ class SQLiteDialect(default.DefaultDialect):
def parse_uqs():
UNIQUE_PATTERN = r'(?:CONSTRAINT "?(.+?)"? +)?UNIQUE *\((.+?)\)'
INLINE_UNIQUE_PATTERN = (
- r'(?:(".+?")|([a-z0-9]+)) '
- r'+[a-z0-9_ ]+? +UNIQUE')
+ r'(?:(".+?")|([a-z0-9]+)) ' r"+[a-z0-9_ ]+? +UNIQUE"
+ )
for match in re.finditer(UNIQUE_PATTERN, table_data, re.I):
name, cols = match.group(1, 2)
@@ -1666,34 +1888,29 @@ class SQLiteDialect(default.DefaultDialect):
# are kind of the same thing :)
for match in re.finditer(INLINE_UNIQUE_PATTERN, table_data, re.I):
cols = list(
- self._find_cols_in_sig(match.group(1) or match.group(2)))
+ self._find_cols_in_sig(match.group(1) or match.group(2))
+ )
yield None, cols
for name, cols in parse_uqs():
sig = tuple(cols)
if sig in auto_index_by_sig:
auto_index_by_sig.pop(sig)
- parsed_constraint = {
- 'name': name,
- 'column_names': cols
- }
+ parsed_constraint = {"name": name, "column_names": cols}
unique_constraints.append(parsed_constraint)
# NOTE: auto_index_by_sig might not be empty here,
# the PRIMARY KEY may have an entry.
return unique_constraints
@reflection.cache
- def get_check_constraints(self, connection, table_name,
- schema=None, **kw):
+ def get_check_constraints(self, connection, table_name, schema=None, **kw):
table_data = self._get_table_sql(
- connection, table_name, schema=schema, **kw)
+ connection, table_name, schema=schema, **kw
+ )
if not table_data:
return []
- CHECK_PATTERN = (
- r'(?:CONSTRAINT (\w+) +)?'
- r'CHECK *\( *(.+) *\),? *'
- )
+ CHECK_PATTERN = r"(?:CONSTRAINT (\w+) +)?" r"CHECK *\( *(.+) *\),? *"
check_constraints = []
# NOTE: we aren't using re.S here because we actually are
# taking advantage of each CHECK constraint being all on one
@@ -1701,25 +1918,26 @@ class SQLiteDialect(default.DefaultDialect):
# necessarily makes assumptions as to how the CREATE TABLE
# was emitted.
for match in re.finditer(CHECK_PATTERN, table_data, re.I):
- check_constraints.append({
- 'sqltext': match.group(2),
- 'name': match.group(1)
- })
+ check_constraints.append(
+ {"sqltext": match.group(2), "name": match.group(1)}
+ )
return check_constraints
@reflection.cache
def get_indexes(self, connection, table_name, schema=None, **kw):
pragma_indexes = self._get_table_pragma(
- connection, "index_list", table_name, schema=schema)
+ connection, "index_list", table_name, schema=schema
+ )
indexes = []
- include_auto_indexes = kw.pop('include_auto_indexes', False)
+ include_auto_indexes = kw.pop("include_auto_indexes", False)
for row in pragma_indexes:
# ignore implicit primary key index.
# http://www.mail-archive.com/sqlite-users@sqlite.org/msg30517.html
- if (not include_auto_indexes and
- row[1].startswith('sqlite_autoindex')):
+ if not include_auto_indexes and row[1].startswith(
+ "sqlite_autoindex"
+ ):
continue
indexes.append(dict(name=row[1], column_names=[], unique=row[2]))
@@ -1727,34 +1945,38 @@ class SQLiteDialect(default.DefaultDialect):
# loop thru unique indexes to get the column names.
for idx in indexes:
pragma_index = self._get_table_pragma(
- connection, "index_info", idx['name'])
+ connection, "index_info", idx["name"]
+ )
for row in pragma_index:
- idx['column_names'].append(row[2])
+ idx["column_names"].append(row[2])
return indexes
@reflection.cache
def _get_table_sql(self, connection, table_name, schema=None, **kw):
if schema:
schema_expr = "%s." % (
- self.identifier_preparer.quote_identifier(schema))
+ self.identifier_preparer.quote_identifier(schema)
+ )
else:
schema_expr = ""
try:
- s = ("SELECT sql FROM "
- " (SELECT * FROM %(schema)ssqlite_master UNION ALL "
- " SELECT * FROM %(schema)ssqlite_temp_master) "
- "WHERE name = '%(table)s' "
- "AND type = 'table'" % {
- "schema": schema_expr,
- "table": table_name})
+ s = (
+ "SELECT sql FROM "
+ " (SELECT * FROM %(schema)ssqlite_master UNION ALL "
+ " SELECT * FROM %(schema)ssqlite_temp_master) "
+ "WHERE name = '%(table)s' "
+ "AND type = 'table'"
+ % {"schema": schema_expr, "table": table_name}
+ )
rs = connection.execute(s)
except exc.DBAPIError:
- s = ("SELECT sql FROM %(schema)ssqlite_master "
- "WHERE name = '%(table)s' "
- "AND type = 'table'" % {
- "schema": schema_expr,
- "table": table_name})
+ s = (
+ "SELECT sql FROM %(schema)ssqlite_master "
+ "WHERE name = '%(table)s' "
+ "AND type = 'table'"
+ % {"schema": schema_expr, "table": table_name}
+ )
rs = connection.execute(s)
return rs.scalar()
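
Every reflection method above ultimately reads PRAGMA output via _get_table_pragma; a raw-DBAPI sketch of the same lookups against a hypothetical table:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE account (id INTEGER PRIMARY KEY, name VARCHAR(30))"
    )

    # the rows behind get_columns(): (cid, name, type, notnull, default, pk)
    for cid, name, type_, notnull, dflt, pk in conn.execute(
        "PRAGMA table_info(account)"
    ):
        print(name, type_.upper(), not notnull, dflt, pk)

    # the rows behind get_indexes()
    print(conn.execute("PRAGMA index_list(account)").fetchall())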
diff --git a/lib/sqlalchemy/dialects/sqlite/json.py b/lib/sqlalchemy/dialects/sqlite/json.py
index 90929fbd8..db185dd4d 100644
--- a/lib/sqlalchemy/dialects/sqlite/json.py
+++ b/lib/sqlalchemy/dialects/sqlite/json.py
@@ -58,7 +58,6 @@ class _FormatTypeMixin(object):
class JSONIndexType(_FormatTypeMixin, sqltypes.JSON.JSONIndexType):
-
def _format_value(self, value):
if isinstance(value, int):
value = "$[%s]" % value
@@ -70,8 +69,10 @@ class JSONIndexType(_FormatTypeMixin, sqltypes.JSON.JSONIndexType):
class JSONPathType(_FormatTypeMixin, sqltypes.JSON.JSONPathType):
def _format_value(self, value):
return "$%s" % (
- "".join([
- "[%s]" % elem if isinstance(elem, int)
- else '."%s"' % elem for elem in value
- ])
+ "".join(
+ [
+ "[%s]" % elem if isinstance(elem, int) else '."%s"' % elem
+ for elem in value
+ ]
+ )
)
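
The comprehension above turns integer path elements into subscripts and string elements into quoted member lookups; a standalone restatement of that logic:

    path = ["data", 0, "name"]
    formatted = "$%s" % "".join(
        "[%s]" % elem if isinstance(elem, int) else '."%s"' % elem
        for elem in path
    )
    assert formatted == '$."data"[0]."name"'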
diff --git a/lib/sqlalchemy/dialects/sqlite/pysqlcipher.py b/lib/sqlalchemy/dialects/sqlite/pysqlcipher.py
index 09f2b8009..fca425127 100644
--- a/lib/sqlalchemy/dialects/sqlite/pysqlcipher.py
+++ b/lib/sqlalchemy/dialects/sqlite/pysqlcipher.py
@@ -82,9 +82,9 @@ from ... import pool
class SQLiteDialect_pysqlcipher(SQLiteDialect_pysqlite):
- driver = 'pysqlcipher'
+ driver = "pysqlcipher"
- pragmas = ('kdf_iter', 'cipher', 'cipher_page_size', 'cipher_use_hmac')
+ pragmas = ("kdf_iter", "cipher", "cipher_page_size", "cipher_use_hmac")
@classmethod
def dbapi(cls):
@@ -102,15 +102,13 @@ class SQLiteDialect_pysqlcipher(SQLiteDialect_pysqlite):
return pool.SingletonThreadPool
def connect(self, *cargs, **cparams):
- passphrase = cparams.pop('passphrase', '')
+ passphrase = cparams.pop("passphrase", "")
- pragmas = dict(
- (key, cparams.pop(key, None)) for key in
- self.pragmas
- )
+ pragmas = dict((key, cparams.pop(key, None)) for key in self.pragmas)
- conn = super(SQLiteDialect_pysqlcipher, self).\
- connect(*cargs, **cparams)
+ conn = super(SQLiteDialect_pysqlcipher, self).connect(
+ *cargs, **cparams
+ )
conn.execute('pragma key="%s"' % passphrase)
for prag, value in pragmas.items():
if value is not None:
@@ -120,11 +118,17 @@ class SQLiteDialect_pysqlcipher(SQLiteDialect_pysqlite):
def create_connect_args(self, url):
super_url = _url.URL(
- url.drivername, username=url.username,
- host=url.host, database=url.database, query=url.query)
- c_args, opts = super(SQLiteDialect_pysqlcipher, self).\
- create_connect_args(super_url)
- opts['passphrase'] = url.password
+ url.drivername,
+ username=url.username,
+ host=url.host,
+ database=url.database,
+ query=url.query,
+ )
+ c_args, opts = super(
+ SQLiteDialect_pysqlcipher, self
+ ).create_connect_args(super_url)
+ opts["passphrase"] = url.password
return c_args, opts
+
dialect = SQLiteDialect_pysqlcipher
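
In the connect flow above, the URL password becomes the passphrase and recognized query arguments become cipher pragmas; a sketch of the documented URL form (the key and file name are placeholders):

    from sqlalchemy import create_engine

    engine = create_engine(
        "sqlite+pysqlcipher://:secret-key@/encrypted.db?kdf_iter=64000"
    )
    # on connect: pragma key="secret-key", then pragma kdf_iter=64000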
diff --git a/lib/sqlalchemy/dialects/sqlite/pysqlite.py b/lib/sqlalchemy/dialects/sqlite/pysqlite.py
index 8809962df..e78d76ae6 100644
--- a/lib/sqlalchemy/dialects/sqlite/pysqlite.py
+++ b/lib/sqlalchemy/dialects/sqlite/pysqlite.py
@@ -301,20 +301,20 @@ class _SQLite_pysqliteDate(DATE):
class SQLiteDialect_pysqlite(SQLiteDialect):
- default_paramstyle = 'qmark'
+ default_paramstyle = "qmark"
colspecs = util.update_copy(
SQLiteDialect.colspecs,
{
sqltypes.Date: _SQLite_pysqliteDate,
sqltypes.TIMESTAMP: _SQLite_pysqliteTimeStamp,
- }
+ },
)
if not util.py2k:
description_encoding = None
- driver = 'pysqlite'
+ driver = "pysqlite"
def __init__(self, **kwargs):
SQLiteDialect.__init__(self, **kwargs)
@@ -323,10 +323,13 @@ class SQLiteDialect_pysqlite(SQLiteDialect):
sqlite_ver = self.dbapi.version_info
if sqlite_ver < (2, 1, 3):
util.warn(
- ("The installed version of pysqlite2 (%s) is out-dated "
- "and will cause errors in some cases. Version 2.1.3 "
- "or greater is recommended.") %
- '.'.join([str(subver) for subver in sqlite_ver]))
+ (
+ "The installed version of pysqlite2 (%s) is out-dated "
+ "and will cause errors in some cases. Version 2.1.3 "
+ "or greater is recommended."
+ )
+ % ".".join([str(subver) for subver in sqlite_ver])
+ )
@classmethod
def dbapi(cls):
@@ -341,7 +344,7 @@ class SQLiteDialect_pysqlite(SQLiteDialect):
@classmethod
def get_pool_class(cls, url):
- if url.database and url.database != ':memory:':
+ if url.database and url.database != ":memory:":
return pool.NullPool
else:
return pool.SingletonThreadPool
@@ -356,22 +359,25 @@ class SQLiteDialect_pysqlite(SQLiteDialect):
"Valid SQLite URL forms are:\n"
" sqlite:///:memory: (or, sqlite://)\n"
" sqlite:///relative/path/to/file.db\n"
- " sqlite:////absolute/path/to/file.db" % (url,))
- filename = url.database or ':memory:'
- if filename != ':memory:':
+ " sqlite:////absolute/path/to/file.db" % (url,)
+ )
+ filename = url.database or ":memory:"
+ if filename != ":memory:":
filename = os.path.abspath(filename)
opts = url.query.copy()
- util.coerce_kw_type(opts, 'timeout', float)
- util.coerce_kw_type(opts, 'isolation_level', str)
- util.coerce_kw_type(opts, 'detect_types', int)
- util.coerce_kw_type(opts, 'check_same_thread', bool)
- util.coerce_kw_type(opts, 'cached_statements', int)
+ util.coerce_kw_type(opts, "timeout", float)
+ util.coerce_kw_type(opts, "isolation_level", str)
+ util.coerce_kw_type(opts, "detect_types", int)
+ util.coerce_kw_type(opts, "check_same_thread", bool)
+ util.coerce_kw_type(opts, "cached_statements", int)
return ([filename], opts)
def is_disconnect(self, e, connection, cursor):
- return isinstance(e, self.dbapi.ProgrammingError) and \
- "Cannot operate on a closed database." in str(e)
+ return isinstance(
+ e, self.dbapi.ProgrammingError
+ ) and "Cannot operate on a closed database." in str(e)
+
dialect = SQLiteDialect_pysqlite
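
The URL forms listed in the error message above, plus the coerce_kw_type calls, translate to the following engine spellings (file paths are placeholders):

    from sqlalchemy import create_engine

    create_engine("sqlite://")                    # in-memory database
    create_engine("sqlite:///relative/app.db")    # relative file path
    create_engine("sqlite:////tmp/app.db")        # absolute file path
    create_engine(
        "sqlite:///app.db?timeout=30&check_same_thread=false"
    )  # timeout coerced to float, check_same_thread to bool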
diff --git a/lib/sqlalchemy/dialects/sybase/__init__.py b/lib/sqlalchemy/dialects/sybase/__init__.py
index be434977f..2f55d3bf6 100644
--- a/lib/sqlalchemy/dialects/sybase/__init__.py
+++ b/lib/sqlalchemy/dialects/sybase/__init__.py
@@ -7,21 +7,61 @@
from . import base, pysybase, pyodbc # noqa
-from .base import CHAR, VARCHAR, TIME, NCHAR, NVARCHAR,\
- TEXT, DATE, DATETIME, FLOAT, NUMERIC,\
- BIGINT, INT, INTEGER, SMALLINT, BINARY,\
- VARBINARY, UNITEXT, UNICHAR, UNIVARCHAR,\
- IMAGE, BIT, MONEY, SMALLMONEY, TINYINT
+from .base import (
+ CHAR,
+ VARCHAR,
+ TIME,
+ NCHAR,
+ NVARCHAR,
+ TEXT,
+ DATE,
+ DATETIME,
+ FLOAT,
+ NUMERIC,
+ BIGINT,
+ INT,
+ INTEGER,
+ SMALLINT,
+ BINARY,
+ VARBINARY,
+ UNITEXT,
+ UNICHAR,
+ UNIVARCHAR,
+ IMAGE,
+ BIT,
+ MONEY,
+ SMALLMONEY,
+ TINYINT,
+)
# default dialect
base.dialect = dialect = pyodbc.dialect
__all__ = (
- 'CHAR', 'VARCHAR', 'TIME', 'NCHAR', 'NVARCHAR',
- 'TEXT', 'DATE', 'DATETIME', 'FLOAT', 'NUMERIC',
- 'BIGINT', 'INT', 'INTEGER', 'SMALLINT', 'BINARY',
- 'VARBINARY', 'UNITEXT', 'UNICHAR', 'UNIVARCHAR',
- 'IMAGE', 'BIT', 'MONEY', 'SMALLMONEY', 'TINYINT',
- 'dialect'
+ "CHAR",
+ "VARCHAR",
+ "TIME",
+ "NCHAR",
+ "NVARCHAR",
+ "TEXT",
+ "DATE",
+ "DATETIME",
+ "FLOAT",
+ "NUMERIC",
+ "BIGINT",
+ "INT",
+ "INTEGER",
+ "SMALLINT",
+ "BINARY",
+ "VARBINARY",
+ "UNITEXT",
+ "UNICHAR",
+ "UNIVARCHAR",
+ "IMAGE",
+ "BIT",
+ "MONEY",
+ "SMALLMONEY",
+ "TINYINT",
+ "dialect",
)
diff --git a/lib/sqlalchemy/dialects/sybase/base.py b/lib/sqlalchemy/dialects/sybase/base.py
index 7dd973573..1214a9279 100644
--- a/lib/sqlalchemy/dialects/sybase/base.py
+++ b/lib/sqlalchemy/dialects/sybase/base.py
@@ -31,70 +31,257 @@ from sqlalchemy.sql import operators as sql_operators
from sqlalchemy import schema as sa_schema
from sqlalchemy import util, sql, exc
-from sqlalchemy.types import CHAR, VARCHAR, TIME, NCHAR, NVARCHAR,\
- TEXT, DATE, DATETIME, FLOAT, NUMERIC,\
- BIGINT, INT, INTEGER, SMALLINT, BINARY,\
- VARBINARY, DECIMAL, TIMESTAMP, Unicode,\
- UnicodeText, REAL
-
-RESERVED_WORDS = set([
- "add", "all", "alter", "and",
- "any", "as", "asc", "backup",
- "begin", "between", "bigint", "binary",
- "bit", "bottom", "break", "by",
- "call", "capability", "cascade", "case",
- "cast", "char", "char_convert", "character",
- "check", "checkpoint", "close", "comment",
- "commit", "connect", "constraint", "contains",
- "continue", "convert", "create", "cross",
- "cube", "current", "current_timestamp", "current_user",
- "cursor", "date", "dbspace", "deallocate",
- "dec", "decimal", "declare", "default",
- "delete", "deleting", "desc", "distinct",
- "do", "double", "drop", "dynamic",
- "else", "elseif", "encrypted", "end",
- "endif", "escape", "except", "exception",
- "exec", "execute", "existing", "exists",
- "externlogin", "fetch", "first", "float",
- "for", "force", "foreign", "forward",
- "from", "full", "goto", "grant",
- "group", "having", "holdlock", "identified",
- "if", "in", "index", "index_lparen",
- "inner", "inout", "insensitive", "insert",
- "inserting", "install", "instead", "int",
- "integer", "integrated", "intersect", "into",
- "iq", "is", "isolation", "join",
- "key", "lateral", "left", "like",
- "lock", "login", "long", "match",
- "membership", "message", "mode", "modify",
- "natural", "new", "no", "noholdlock",
- "not", "notify", "null", "numeric",
- "of", "off", "on", "open",
- "option", "options", "or", "order",
- "others", "out", "outer", "over",
- "passthrough", "precision", "prepare", "primary",
- "print", "privileges", "proc", "procedure",
- "publication", "raiserror", "readtext", "real",
- "reference", "references", "release", "remote",
- "remove", "rename", "reorganize", "resource",
- "restore", "restrict", "return", "revoke",
- "right", "rollback", "rollup", "save",
- "savepoint", "scroll", "select", "sensitive",
- "session", "set", "setuser", "share",
- "smallint", "some", "sqlcode", "sqlstate",
- "start", "stop", "subtrans", "subtransaction",
- "synchronize", "syntax_error", "table", "temporary",
- "then", "time", "timestamp", "tinyint",
- "to", "top", "tran", "trigger",
- "truncate", "tsequal", "unbounded", "union",
- "unique", "unknown", "unsigned", "update",
- "updating", "user", "using", "validate",
- "values", "varbinary", "varchar", "variable",
- "varying", "view", "wait", "waitfor",
- "when", "where", "while", "window",
- "with", "with_cube", "with_lparen", "with_rollup",
- "within", "work", "writetext",
-])
+from sqlalchemy.types import (
+ CHAR,
+ VARCHAR,
+ TIME,
+ NCHAR,
+ NVARCHAR,
+ TEXT,
+ DATE,
+ DATETIME,
+ FLOAT,
+ NUMERIC,
+ BIGINT,
+ INT,
+ INTEGER,
+ SMALLINT,
+ BINARY,
+ VARBINARY,
+ DECIMAL,
+ TIMESTAMP,
+ Unicode,
+ UnicodeText,
+ REAL,
+)
+
+RESERVED_WORDS = set(
+ [
+ "add",
+ "all",
+ "alter",
+ "and",
+ "any",
+ "as",
+ "asc",
+ "backup",
+ "begin",
+ "between",
+ "bigint",
+ "binary",
+ "bit",
+ "bottom",
+ "break",
+ "by",
+ "call",
+ "capability",
+ "cascade",
+ "case",
+ "cast",
+ "char",
+ "char_convert",
+ "character",
+ "check",
+ "checkpoint",
+ "close",
+ "comment",
+ "commit",
+ "connect",
+ "constraint",
+ "contains",
+ "continue",
+ "convert",
+ "create",
+ "cross",
+ "cube",
+ "current",
+ "current_timestamp",
+ "current_user",
+ "cursor",
+ "date",
+ "dbspace",
+ "deallocate",
+ "dec",
+ "decimal",
+ "declare",
+ "default",
+ "delete",
+ "deleting",
+ "desc",
+ "distinct",
+ "do",
+ "double",
+ "drop",
+ "dynamic",
+ "else",
+ "elseif",
+ "encrypted",
+ "end",
+ "endif",
+ "escape",
+ "except",
+ "exception",
+ "exec",
+ "execute",
+ "existing",
+ "exists",
+ "externlogin",
+ "fetch",
+ "first",
+ "float",
+ "for",
+ "force",
+ "foreign",
+ "forward",
+ "from",
+ "full",
+ "goto",
+ "grant",
+ "group",
+ "having",
+ "holdlock",
+ "identified",
+ "if",
+ "in",
+ "index",
+ "index_lparen",
+ "inner",
+ "inout",
+ "insensitive",
+ "insert",
+ "inserting",
+ "install",
+ "instead",
+ "int",
+ "integer",
+ "integrated",
+ "intersect",
+ "into",
+ "iq",
+ "is",
+ "isolation",
+ "join",
+ "key",
+ "lateral",
+ "left",
+ "like",
+ "lock",
+ "login",
+ "long",
+ "match",
+ "membership",
+ "message",
+ "mode",
+ "modify",
+ "natural",
+ "new",
+ "no",
+ "noholdlock",
+ "not",
+ "notify",
+ "null",
+ "numeric",
+ "of",
+ "off",
+ "on",
+ "open",
+ "option",
+ "options",
+ "or",
+ "order",
+ "others",
+ "out",
+ "outer",
+ "over",
+ "passthrough",
+ "precision",
+ "prepare",
+ "primary",
+ "print",
+ "privileges",
+ "proc",
+ "procedure",
+ "publication",
+ "raiserror",
+ "readtext",
+ "real",
+ "reference",
+ "references",
+ "release",
+ "remote",
+ "remove",
+ "rename",
+ "reorganize",
+ "resource",
+ "restore",
+ "restrict",
+ "return",
+ "revoke",
+ "right",
+ "rollback",
+ "rollup",
+ "save",
+ "savepoint",
+ "scroll",
+ "select",
+ "sensitive",
+ "session",
+ "set",
+ "setuser",
+ "share",
+ "smallint",
+ "some",
+ "sqlcode",
+ "sqlstate",
+ "start",
+ "stop",
+ "subtrans",
+ "subtransaction",
+ "synchronize",
+ "syntax_error",
+ "table",
+ "temporary",
+ "then",
+ "time",
+ "timestamp",
+ "tinyint",
+ "to",
+ "top",
+ "tran",
+ "trigger",
+ "truncate",
+ "tsequal",
+ "unbounded",
+ "union",
+ "unique",
+ "unknown",
+ "unsigned",
+ "update",
+ "updating",
+ "user",
+ "using",
+ "validate",
+ "values",
+ "varbinary",
+ "varchar",
+ "variable",
+ "varying",
+ "view",
+ "wait",
+ "waitfor",
+ "when",
+ "where",
+ "while",
+ "window",
+ "with",
+ "with_cube",
+ "with_lparen",
+ "with_rollup",
+ "within",
+ "work",
+ "writetext",
+ ]
+)
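
For orientation, RESERVED_WORDS feeds the dialect's identifier preparer, which quotes any name that collides with this set. A minimal sketch of that test, assuming the set above is in scope (the real check in SQLAlchemy's IdentifierPreparer also validates legal characters):

def needs_quoting(name, reserved_words=RESERVED_WORDS):
    # Simplified: quote reserved words, and anything that is not
    # already lower-case, since unquoted names get case-folded.
    return name.lower() in reserved_words or name.lower() != name

print(needs_quoting("user"))     # True, "user" is in the set above
print(needs_quoting("user_id"))  # False, rendered unquoted
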
class _SybaseUnitypeMixin(object):
@@ -106,27 +293,28 @@ class _SybaseUnitypeMixin(object):
return str(value) # decode("ucs-2")
else:
return None
+
return process
class UNICHAR(_SybaseUnitypeMixin, sqltypes.Unicode):
- __visit_name__ = 'UNICHAR'
+ __visit_name__ = "UNICHAR"
class UNIVARCHAR(_SybaseUnitypeMixin, sqltypes.Unicode):
- __visit_name__ = 'UNIVARCHAR'
+ __visit_name__ = "UNIVARCHAR"
class UNITEXT(_SybaseUnitypeMixin, sqltypes.UnicodeText):
- __visit_name__ = 'UNITEXT'
+ __visit_name__ = "UNITEXT"
class TINYINT(sqltypes.Integer):
- __visit_name__ = 'TINYINT'
+ __visit_name__ = "TINYINT"
class BIT(sqltypes.TypeEngine):
- __visit_name__ = 'BIT'
+ __visit_name__ = "BIT"
class MONEY(sqltypes.TypeEngine):
@@ -142,7 +330,7 @@ class UNIQUEIDENTIFIER(sqltypes.TypeEngine):
class IMAGE(sqltypes.LargeBinary):
- __visit_name__ = 'IMAGE'
+ __visit_name__ = "IMAGE"
class SybaseTypeCompiler(compiler.GenericTypeCompiler):
@@ -182,67 +370,66 @@ class SybaseTypeCompiler(compiler.GenericTypeCompiler):
def visit_UNIQUEIDENTIFIER(self, type_, **kw):
return "UNIQUEIDENTIFIER"
-ischema_names = {
- 'bigint': BIGINT,
- 'int': INTEGER,
- 'integer': INTEGER,
- 'smallint': SMALLINT,
- 'tinyint': TINYINT,
- 'unsigned bigint': BIGINT, # TODO: unsigned flags
- 'unsigned int': INTEGER, # TODO: unsigned flags
- 'unsigned smallint': SMALLINT, # TODO: unsigned flags
- 'numeric': NUMERIC,
- 'decimal': DECIMAL,
- 'dec': DECIMAL,
- 'float': FLOAT,
- 'double': NUMERIC, # TODO
- 'double precision': NUMERIC, # TODO
- 'real': REAL,
- 'smallmoney': SMALLMONEY,
- 'money': MONEY,
- 'smalldatetime': DATETIME,
- 'datetime': DATETIME,
- 'date': DATE,
- 'time': TIME,
- 'char': CHAR,
- 'character': CHAR,
- 'varchar': VARCHAR,
- 'character varying': VARCHAR,
- 'char varying': VARCHAR,
- 'unichar': UNICHAR,
- 'unicode character': UNIVARCHAR,
- 'nchar': NCHAR,
- 'national char': NCHAR,
- 'national character': NCHAR,
- 'nvarchar': NVARCHAR,
- 'nchar varying': NVARCHAR,
- 'national char varying': NVARCHAR,
- 'national character varying': NVARCHAR,
- 'text': TEXT,
- 'unitext': UNITEXT,
- 'binary': BINARY,
- 'varbinary': VARBINARY,
- 'image': IMAGE,
- 'bit': BIT,
+ischema_names = {
+ "bigint": BIGINT,
+ "int": INTEGER,
+ "integer": INTEGER,
+ "smallint": SMALLINT,
+ "tinyint": TINYINT,
+ "unsigned bigint": BIGINT, # TODO: unsigned flags
+ "unsigned int": INTEGER, # TODO: unsigned flags
+ "unsigned smallint": SMALLINT, # TODO: unsigned flags
+ "numeric": NUMERIC,
+ "decimal": DECIMAL,
+ "dec": DECIMAL,
+ "float": FLOAT,
+ "double": NUMERIC, # TODO
+ "double precision": NUMERIC, # TODO
+ "real": REAL,
+ "smallmoney": SMALLMONEY,
+ "money": MONEY,
+ "smalldatetime": DATETIME,
+ "datetime": DATETIME,
+ "date": DATE,
+ "time": TIME,
+ "char": CHAR,
+ "character": CHAR,
+ "varchar": VARCHAR,
+ "character varying": VARCHAR,
+ "char varying": VARCHAR,
+ "unichar": UNICHAR,
+ "unicode character": UNIVARCHAR,
+ "nchar": NCHAR,
+ "national char": NCHAR,
+ "national character": NCHAR,
+ "nvarchar": NVARCHAR,
+ "nchar varying": NVARCHAR,
+ "national char varying": NVARCHAR,
+ "national character varying": NVARCHAR,
+ "text": TEXT,
+ "unitext": UNITEXT,
+ "binary": BINARY,
+ "varbinary": VARBINARY,
+ "image": IMAGE,
+ "bit": BIT,
# not in documentation for ASE 15.7
- 'long varchar': TEXT, # TODO
- 'timestamp': TIMESTAMP,
- 'uniqueidentifier': UNIQUEIDENTIFIER,
-
+ "long varchar": TEXT, # TODO
+ "timestamp": TIMESTAMP,
+ "uniqueidentifier": UNIQUEIDENTIFIER,
}
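
ischema_names is the reflection lookup table: the type string reported by Sybase's system catalogs resolves here to a SQLAlchemy type class, with unknown names falling back to NULLTYPE plus a warning (see _get_column_info below). A toy lookup, assuming the dict above:

# Illustrative only; mirrors the .get() in _get_column_info further down.
for reported in ("unsigned int", "national char varying", "no_such_type"):
    coltype = ischema_names.get(reported)
    print(reported, "->", coltype.__name__ if coltype else "NULLTYPE fallback")
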
class SybaseInspector(reflection.Inspector):
-
def __init__(self, conn):
reflection.Inspector.__init__(self, conn)
def get_table_id(self, table_name, schema=None):
"""Return the table id from `table_name` and `schema`."""
- return self.dialect.get_table_id(self.bind, table_name, schema,
- info_cache=self.info_cache)
+ return self.dialect.get_table_id(
+ self.bind, table_name, schema, info_cache=self.info_cache
+ )
class SybaseExecutionContext(default.DefaultExecutionContext):
@@ -267,15 +454,17 @@ class SybaseExecutionContext(default.DefaultExecutionContext):
insert_has_sequence = seq_column is not None
if insert_has_sequence:
- self._enable_identity_insert = \
+ self._enable_identity_insert = (
seq_column.key in self.compiled_parameters[0]
+ )
else:
self._enable_identity_insert = False
if self._enable_identity_insert:
self.cursor.execute(
- "SET IDENTITY_INSERT %s ON" %
- self.dialect.identifier_preparer.format_table(tbl))
+ "SET IDENTITY_INSERT %s ON"
+ % self.dialect.identifier_preparer.format_table(tbl)
+ )
if self.isddl:
# TODO: to enhance this, we can detect "ddl in tran" on the
@@ -284,14 +473,16 @@ class SybaseExecutionContext(default.DefaultExecutionContext):
if not self.should_autocommit:
raise exc.InvalidRequestError(
"The Sybase dialect only supports "
- "DDL in 'autocommit' mode at this time.")
+ "DDL in 'autocommit' mode at this time."
+ )
self.root_connection.engine.logger.info(
- "AUTOCOMMIT (Assuming no Sybase 'ddl in tran')")
+ "AUTOCOMMIT (Assuming no Sybase 'ddl in tran')"
+ )
self.set_ddl_autocommit(
- self.root_connection.connection.connection,
- True)
+ self.root_connection.connection.connection, True
+ )
def post_exec(self):
if self.isddl:
@@ -299,9 +490,10 @@ class SybaseExecutionContext(default.DefaultExecutionContext):
if self._enable_identity_insert:
self.cursor.execute(
- "SET IDENTITY_INSERT %s OFF" %
- self.dialect.identifier_preparer.
- format_table(self.compiled.statement.table)
+ "SET IDENTITY_INSERT %s OFF"
+ % self.dialect.identifier_preparer.format_table(
+ self.compiled.statement.table
+ )
)
def get_lastrowid(self):
@@ -317,11 +509,8 @@ class SybaseSQLCompiler(compiler.SQLCompiler):
extract_map = util.update_copy(
compiler.SQLCompiler.extract_map,
- {
- 'doy': 'dayofyear',
- 'dow': 'weekday',
- 'milliseconds': 'millisecond'
- })
+ {"doy": "dayofyear", "dow": "weekday", "milliseconds": "millisecond"},
+ )
def get_select_precolumns(self, select, **kw):
s = select._distinct and "DISTINCT " or ""
@@ -330,9 +519,9 @@ class SybaseSQLCompiler(compiler.SQLCompiler):
limit = select._limit
if limit:
# if select._limit == 1:
- # s += "FIRST "
+ # s += "FIRST "
# else:
- # s += "TOP %s " % (select._limit,)
+ # s += "TOP %s " % (select._limit,)
s += "TOP %s " % (limit,)
offset = select._offset
if offset:
@@ -348,8 +537,7 @@ class SybaseSQLCompiler(compiler.SQLCompiler):
def visit_extract(self, extract, **kw):
field = self.extract_map.get(extract.field, extract.field)
- return 'DATEPART("%s", %s)' % (
- field, self.process(extract.expr, **kw))
+ return 'DATEPART("%s", %s)' % (field, self.process(extract.expr, **kw))
def visit_now_func(self, fn, **kw):
return "GETDATE()"
@@ -357,10 +545,10 @@ class SybaseSQLCompiler(compiler.SQLCompiler):
def for_update_clause(self, select):
# "FOR UPDATE" is only allowed on "DECLARE CURSOR"
# which SQLAlchemy doesn't use
- return ''
+ return ""
def order_by_clause(self, select, **kw):
- kw['literal_binds'] = True
+ kw["literal_binds"] = True
order_by = self.process(select._order_by_clause, **kw)
# SybaseSQL only allows ORDER BY in subqueries if there is a LIMIT
@@ -369,8 +557,7 @@ class SybaseSQLCompiler(compiler.SQLCompiler):
else:
return ""
- def delete_table_clause(self, delete_stmt, from_table,
- extra_froms):
+ def delete_table_clause(self, delete_stmt, from_table, extra_froms):
"""If we have extra froms make sure we render any alias as hint."""
ashint = False
if extra_froms:
@@ -379,34 +566,41 @@ class SybaseSQLCompiler(compiler.SQLCompiler):
self, asfrom=True, iscrud=True, ashint=ashint
)
- def delete_extra_from_clause(self, delete_stmt, from_table,
- extra_froms, from_hints, **kw):
+ def delete_extra_from_clause(
+ self, delete_stmt, from_table, extra_froms, from_hints, **kw
+ ):
"""Render the DELETE .. FROM clause specific to Sybase."""
- return "FROM " + ', '.join(
- t._compiler_dispatch(self, asfrom=True,
- fromhints=from_hints, **kw)
- for t in [from_table] + extra_froms)
+ return "FROM " + ", ".join(
+ t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw)
+ for t in [from_table] + extra_froms
+ )
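
Together, delete_table_clause and delete_extra_from_clause produce Sybase's double-FROM form for multi-table deletes; roughly, for hypothetical tables t1 and t2:

# DELETE FROM t1 FROM t1, t2 WHERE t1.id = t2.t1_id
# (the target appears twice: once as the delete target rendered as a
# hint, once in the FROM list built by delete_extra_from_clause)
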
class SybaseDDLCompiler(compiler.DDLCompiler):
def get_column_specification(self, column, **kwargs):
- colspec = self.preparer.format_column(column) + " " + \
- self.dialect.type_compiler.process(
- column.type, type_expression=column)
+ colspec = (
+ self.preparer.format_column(column)
+ + " "
+ + self.dialect.type_compiler.process(
+ column.type, type_expression=column
+ )
+ )
if column.table is None:
raise exc.CompileError(
"The Sybase dialect requires Table-bound "
- "columns in order to generate DDL")
+ "columns in order to generate DDL"
+ )
seq_col = column.table._autoincrement_column
# install a IDENTITY Sequence if we have an implicit IDENTITY column
if seq_col is column:
- sequence = isinstance(column.default, sa_schema.Sequence) \
+ sequence = (
+ isinstance(column.default, sa_schema.Sequence)
and column.default
+ )
if sequence:
- start, increment = sequence.start or 1, \
- sequence.increment or 1
+ start, increment = sequence.start or 1, sequence.increment or 1
else:
start, increment = 1, 1
if (start, increment) == (1, 1):
@@ -431,8 +625,7 @@ class SybaseDDLCompiler(compiler.DDLCompiler):
index = drop.element
return "\nDROP INDEX %s.%s" % (
self.preparer.quote_identifier(index.table.name),
- self._prepared_index_name(drop.element,
- include_schema=False)
+ self._prepared_index_name(drop.element, include_schema=False),
)
@@ -441,7 +634,7 @@ class SybaseIdentifierPreparer(compiler.IdentifierPreparer):
class SybaseDialect(default.DefaultDialect):
- name = 'sybase'
+ name = "sybase"
supports_unicode_statements = False
supports_sane_rowcount = False
supports_sane_multi_rowcount = False
@@ -463,14 +656,18 @@ class SybaseDialect(default.DefaultDialect):
def _get_default_schema_name(self, connection):
return connection.scalar(
- text("SELECT user_name() as user_name",
- typemap={'user_name': Unicode})
+ text(
+ "SELECT user_name() as user_name",
+ typemap={"user_name": Unicode},
+ )
)
def initialize(self, connection):
super(SybaseDialect, self).initialize(connection)
- if self.server_version_info is not None and\
- self.server_version_info < (15, ):
+ if (
+ self.server_version_info is not None
+ and self.server_version_info < (15,)
+ ):
self.max_identifier_length = 30
else:
self.max_identifier_length = 255
@@ -488,22 +685,24 @@ class SybaseDialect(default.DefaultDialect):
if schema is None:
schema = self.default_schema_name
- TABLEID_SQL = text("""
+ TABLEID_SQL = text(
+ """
SELECT o.id AS id
FROM sysobjects o JOIN sysusers u ON o.uid=u.uid
WHERE u.name = :schema_name
AND o.name = :table_name
AND o.type in ('U', 'V')
- """)
+ """
+ )
if util.py2k:
if isinstance(schema, unicode):
schema = schema.encode("ascii")
if isinstance(table_name, unicode):
table_name = table_name.encode("ascii")
- result = connection.execute(TABLEID_SQL,
- schema_name=schema,
- table_name=table_name)
+ result = connection.execute(
+ TABLEID_SQL, schema_name=schema, table_name=table_name
+ )
table_id = result.scalar()
if table_id is None:
raise exc.NoSuchTableError(table_name)
@@ -511,10 +710,12 @@ class SybaseDialect(default.DefaultDialect):
@reflection.cache
def get_columns(self, connection, table_name, schema=None, **kw):
- table_id = self.get_table_id(connection, table_name, schema,
- info_cache=kw.get("info_cache"))
+ table_id = self.get_table_id(
+ connection, table_name, schema, info_cache=kw.get("info_cache")
+ )
- COLUMN_SQL = text("""
+ COLUMN_SQL = text(
+ """
SELECT col.name AS name,
t.name AS type,
(col.status & 8) AS nullable,
@@ -528,23 +729,47 @@ class SybaseDialect(default.DefaultDialect):
WHERE col.usertype = t.usertype
AND col.id = :table_id
ORDER BY col.colid
- """)
+ """
+ )
results = connection.execute(COLUMN_SQL, table_id=table_id)
columns = []
- for (name, type_, nullable, autoincrement, default, precision, scale,
- length) in results:
- col_info = self._get_column_info(name, type_, bool(nullable),
- bool(autoincrement),
- default, precision, scale,
- length)
+ for (
+ name,
+ type_,
+ nullable,
+ autoincrement,
+ default,
+ precision,
+ scale,
+ length,
+ ) in results:
+ col_info = self._get_column_info(
+ name,
+ type_,
+ bool(nullable),
+ bool(autoincrement),
+ default,
+ precision,
+ scale,
+ length,
+ )
columns.append(col_info)
return columns
- def _get_column_info(self, name, type_, nullable, autoincrement, default,
- precision, scale, length):
+ def _get_column_info(
+ self,
+ name,
+ type_,
+ nullable,
+ autoincrement,
+ default,
+ precision,
+ scale,
+ length,
+ ):
coltype = self.ischema_names.get(type_, None)
@@ -565,8 +790,9 @@ class SybaseDialect(default.DefaultDialect):
# if is_array:
# coltype = ARRAY(coltype)
else:
- util.warn("Did not recognize type '%s' of column '%s'" %
- (type_, name))
+ util.warn(
+ "Did not recognize type '%s' of column '%s'" % (type_, name)
+ )
coltype = sqltypes.NULLTYPE
if default:
@@ -575,15 +801,21 @@ class SybaseDialect(default.DefaultDialect):
else:
default = None
- column_info = dict(name=name, type=coltype, nullable=nullable,
- default=default, autoincrement=autoincrement)
+ column_info = dict(
+ name=name,
+ type=coltype,
+ nullable=nullable,
+ default=default,
+ autoincrement=autoincrement,
+ )
return column_info
@reflection.cache
def get_foreign_keys(self, connection, table_name, schema=None, **kw):
- table_id = self.get_table_id(connection, table_name, schema,
- info_cache=kw.get("info_cache"))
+ table_id = self.get_table_id(
+ connection, table_name, schema, info_cache=kw.get("info_cache")
+ )
table_cache = {}
column_cache = {}
@@ -591,11 +823,13 @@ class SybaseDialect(default.DefaultDialect):
table_cache[table_id] = {"name": table_name, "schema": schema}
- COLUMN_SQL = text("""
+ COLUMN_SQL = text(
+ """
SELECT c.colid AS id, c.name AS name
FROM syscolumns c
WHERE c.id = :table_id
- """)
+ """
+ )
results = connection.execute(COLUMN_SQL, table_id=table_id)
columns = {}
@@ -603,7 +837,8 @@ class SybaseDialect(default.DefaultDialect):
columns[col["id"]] = col["name"]
column_cache[table_id] = columns
- REFCONSTRAINT_SQL = text("""
+ REFCONSTRAINT_SQL = text(
+ """
SELECT o.name AS name, r.reftabid AS reftable_id,
r.keycnt AS 'count',
r.fokey1 AS fokey1, r.fokey2 AS fokey2, r.fokey3 AS fokey3,
@@ -621,15 +856,19 @@ class SybaseDialect(default.DefaultDialect):
r.refkey16 AS refkey16
FROM sysreferences r JOIN sysobjects o on r.tableid = o.id
WHERE r.tableid = :table_id
- """)
+ """
+ )
referential_constraints = connection.execute(
- REFCONSTRAINT_SQL, table_id=table_id).fetchall()
+ REFCONSTRAINT_SQL, table_id=table_id
+ ).fetchall()
- REFTABLE_SQL = text("""
+ REFTABLE_SQL = text(
+ """
SELECT o.name AS name, u.name AS 'schema'
FROM sysobjects o JOIN sysusers u ON o.uid = u.uid
WHERE o.id = :table_id
- """)
+ """
+ )
for r in referential_constraints:
reftable_id = r["reftable_id"]
@@ -639,8 +878,10 @@ class SybaseDialect(default.DefaultDialect):
reftable = c.fetchone()
c.close()
table_info = {"name": reftable["name"], "schema": None}
- if (schema is not None or
- reftable["schema"] != self.default_schema_name):
+ if (
+ schema is not None
+ or reftable["schema"] != self.default_schema_name
+ ):
table_info["schema"] = reftable["schema"]
table_cache[reftable_id] = table_info
@@ -664,7 +905,7 @@ class SybaseDialect(default.DefaultDialect):
"referred_schema": reftable["schema"],
"referred_table": reftable["name"],
"referred_columns": referred_columns,
- "name": r["name"]
+ "name": r["name"],
}
foreign_keys.append(fk_info)
@@ -673,10 +914,12 @@ class SybaseDialect(default.DefaultDialect):
@reflection.cache
def get_indexes(self, connection, table_name, schema=None, **kw):
- table_id = self.get_table_id(connection, table_name, schema,
- info_cache=kw.get("info_cache"))
+ table_id = self.get_table_id(
+ connection, table_name, schema, info_cache=kw.get("info_cache")
+ )
- INDEX_SQL = text("""
+ INDEX_SQL = text(
+ """
SELECT object_name(i.id) AS table_name,
i.keycnt AS 'count',
i.name AS name,
@@ -702,7 +945,8 @@ class SybaseDialect(default.DefaultDialect):
AND o.id = :table_id
AND (i.status & 2048) = 0
AND i.indid BETWEEN 1 AND 254
- """)
+ """
+ )
results = connection.execute(INDEX_SQL, table_id=table_id)
indexes = []
@@ -710,19 +954,23 @@ class SybaseDialect(default.DefaultDialect):
column_names = []
for i in range(1, r["count"]):
column_names.append(r["col_%i" % (i,)])
- index_info = {"name": r["name"],
- "unique": bool(r["unique"]),
- "column_names": column_names}
+ index_info = {
+ "name": r["name"],
+ "unique": bool(r["unique"]),
+ "column_names": column_names,
+ }
indexes.append(index_info)
return indexes
@reflection.cache
def get_pk_constraint(self, connection, table_name, schema=None, **kw):
- table_id = self.get_table_id(connection, table_name, schema,
- info_cache=kw.get("info_cache"))
+ table_id = self.get_table_id(
+ connection, table_name, schema, info_cache=kw.get("info_cache")
+ )
- PK_SQL = text("""
+ PK_SQL = text(
+ """
SELECT object_name(i.id) AS table_name,
i.keycnt AS 'count',
i.name AS name,
@@ -747,7 +995,8 @@ class SybaseDialect(default.DefaultDialect):
AND o.id = :table_id
AND (i.status & 2048) = 2048
AND i.indid BETWEEN 1 AND 254
- """)
+ """
+ )
results = connection.execute(PK_SQL, table_id=table_id)
pks = results.fetchone()
@@ -757,8 +1006,10 @@ class SybaseDialect(default.DefaultDialect):
if pks:
for i in range(1, pks["count"] + 1):
constrained_columns.append(pks["pk_%i" % (i,)])
- return {"constrained_columns": constrained_columns,
- "name": pks["name"]}
+ return {
+ "constrained_columns": constrained_columns,
+ "name": pks["name"],
+ }
else:
return {"constrained_columns": [], "name": None}
@@ -776,12 +1027,14 @@ class SybaseDialect(default.DefaultDialect):
if schema is None:
schema = self.default_schema_name
- TABLE_SQL = text("""
+ TABLE_SQL = text(
+ """
SELECT o.name AS name
FROM sysobjects o JOIN sysusers u ON o.uid = u.uid
WHERE u.name = :schema_name
AND o.type = 'U'
- """)
+ """
+ )
if util.py2k:
if isinstance(schema, unicode):
@@ -796,12 +1049,14 @@ class SybaseDialect(default.DefaultDialect):
if schema is None:
schema = self.default_schema_name
- VIEW_DEF_SQL = text("""
+ VIEW_DEF_SQL = text(
+ """
SELECT c.text
FROM syscomments c JOIN sysobjects o ON c.id = o.id
WHERE o.name = :view_name
AND o.type = 'V'
- """)
+ """
+ )
if util.py2k:
if isinstance(view_name, unicode):
@@ -816,12 +1071,14 @@ class SybaseDialect(default.DefaultDialect):
if schema is None:
schema = self.default_schema_name
- VIEW_SQL = text("""
+ VIEW_SQL = text(
+ """
SELECT o.name AS name
FROM sysobjects o JOIN sysusers u ON o.uid = u.uid
WHERE u.name = :schema_name
AND o.type = 'V'
- """)
+ """
+ )
if util.py2k:
if isinstance(schema, unicode):
diff --git a/lib/sqlalchemy/dialects/sybase/mxodbc.py b/lib/sqlalchemy/dialects/sybase/mxodbc.py
index ddb6b7e21..eeceb359b 100644
--- a/lib/sqlalchemy/dialects/sybase/mxodbc.py
+++ b/lib/sqlalchemy/dialects/sybase/mxodbc.py
@@ -30,4 +30,5 @@ class SybaseExecutionContext_mxodbc(SybaseExecutionContext):
class SybaseDialect_mxodbc(MxODBCConnector, SybaseDialect):
execution_ctx_cls = SybaseExecutionContext_mxodbc
+
dialect = SybaseDialect_mxodbc
diff --git a/lib/sqlalchemy/dialects/sybase/pyodbc.py b/lib/sqlalchemy/dialects/sybase/pyodbc.py
index af6469dad..a4759428c 100644
--- a/lib/sqlalchemy/dialects/sybase/pyodbc.py
+++ b/lib/sqlalchemy/dialects/sybase/pyodbc.py
@@ -34,8 +34,10 @@ Currently *not* supported are::
"""
-from sqlalchemy.dialects.sybase.base import SybaseDialect,\
- SybaseExecutionContext
+from sqlalchemy.dialects.sybase.base import (
+ SybaseDialect,
+ SybaseExecutionContext,
+)
from sqlalchemy.connectors.pyodbc import PyODBCConnector
from sqlalchemy import types as sqltypes, processors
import decimal
@@ -51,12 +53,10 @@ class _SybNumeric_pyodbc(sqltypes.Numeric):
"""
def bind_processor(self, dialect):
- super_process = super(_SybNumeric_pyodbc, self).\
- bind_processor(dialect)
+ super_process = super(_SybNumeric_pyodbc, self).bind_processor(dialect)
def process(value):
- if self.asdecimal and \
- isinstance(value, decimal.Decimal):
+ if self.asdecimal and isinstance(value, decimal.Decimal):
if value.adjusted() < -6:
return processors.to_float(value)
@@ -65,6 +65,7 @@ class _SybNumeric_pyodbc(sqltypes.Numeric):
return super_process(value)
else:
return value
+
return process
@@ -79,8 +80,7 @@ class SybaseExecutionContext_pyodbc(SybaseExecutionContext):
class SybaseDialect_pyodbc(PyODBCConnector, SybaseDialect):
execution_ctx_cls = SybaseExecutionContext_pyodbc
- colspecs = {
- sqltypes.Numeric: _SybNumeric_pyodbc,
- }
+ colspecs = {sqltypes.Numeric: _SybNumeric_pyodbc}
+
dialect = SybaseDialect_pyodbc
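
_SybNumeric_pyodbc exists because binding Decimals with very small adjusted exponents through Sybase+pyodbc is unreliable, so anything with adjusted() < -6 is converted to a float on the way in. The cutoff in isolation:

import decimal

for s in ("0.001", "0.0000001"):
    v = decimal.Decimal(s)
    # adjusted() is the exponent of the most significant digit
    print(s, v.adjusted(), "-> float" if v.adjusted() < -6 else "-> Decimal")
# 0.001      -3 -> Decimal (bound as-is)
# 0.0000001  -7 -> float   (handed to processors.to_float)
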
diff --git a/lib/sqlalchemy/dialects/sybase/pysybase.py b/lib/sqlalchemy/dialects/sybase/pysybase.py
index 2168d5572..09d2cf380 100644
--- a/lib/sqlalchemy/dialects/sybase/pysybase.py
+++ b/lib/sqlalchemy/dialects/sybase/pysybase.py
@@ -22,8 +22,11 @@ kind at this time.
"""
from sqlalchemy import types as sqltypes, processors
-from sqlalchemy.dialects.sybase.base import SybaseDialect, \
- SybaseExecutionContext, SybaseSQLCompiler
+from sqlalchemy.dialects.sybase.base import (
+ SybaseDialect,
+ SybaseExecutionContext,
+ SybaseSQLCompiler,
+)
class _SybNumeric(sqltypes.Numeric):
@@ -35,7 +38,6 @@ class _SybNumeric(sqltypes.Numeric):
class SybaseExecutionContext_pysybase(SybaseExecutionContext):
-
def set_ddl_autocommit(self, dbapi_connection, value):
if value:
# call commit() on the Sybase connection directly,
@@ -58,24 +60,22 @@ class SybaseSQLCompiler_pysybase(SybaseSQLCompiler):
class SybaseDialect_pysybase(SybaseDialect):
- driver = 'pysybase'
+ driver = "pysybase"
execution_ctx_cls = SybaseExecutionContext_pysybase
statement_compiler = SybaseSQLCompiler_pysybase
- colspecs = {
- sqltypes.Numeric: _SybNumeric,
- sqltypes.Float: sqltypes.Float
- }
+ colspecs = {sqltypes.Numeric: _SybNumeric, sqltypes.Float: sqltypes.Float}
@classmethod
def dbapi(cls):
import Sybase
+
return Sybase
def create_connect_args(self, url):
- opts = url.translate_connect_args(username='user', password='passwd')
+ opts = url.translate_connect_args(username="user", password="passwd")
- return ([opts.pop('host')], opts)
+ return ([opts.pop("host")], opts)
def do_executemany(self, cursor, statement, parameters, context=None):
# calling python-sybase executemany yields:
@@ -90,13 +90,17 @@ class SybaseDialect_pysybase(SybaseDialect):
return (vers / 1000, vers % 1000 / 100, vers % 100 / 10, vers % 10)
def is_disconnect(self, e, connection, cursor):
- if isinstance(e, (self.dbapi.OperationalError,
- self.dbapi.ProgrammingError)):
+ if isinstance(
+ e, (self.dbapi.OperationalError, self.dbapi.ProgrammingError)
+ ):
msg = str(e)
- return ('Unable to complete network request to host' in msg or
- 'Invalid connection state' in msg or
- 'Invalid cursor state' in msg)
+ return (
+ "Unable to complete network request to host" in msg
+ or "Invalid connection state" in msg
+ or "Invalid cursor state" in msg
+ )
else:
return False
+
dialect = SybaseDialect_pysybase
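
create_connect_args above reshapes the URL into the (args, kwargs) pair the Sybase module expects, with the host as the lone positional argument. For example (credentials hypothetical):

from sqlalchemy.engine.url import make_url

url = make_url("sybase+pysybase://scott:tiger@dsn/test")
opts = url.translate_connect_args(username="user", password="passwd")
print([opts.pop("host")], opts)
# ['dsn'] {'user': 'scott', 'passwd': 'tiger', 'database': 'test'}  (key order aside)
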
diff --git a/lib/sqlalchemy/engine/__init__.py b/lib/sqlalchemy/engine/__init__.py
index 6342b3c21..590359c38 100644
--- a/lib/sqlalchemy/engine/__init__.py
+++ b/lib/sqlalchemy/engine/__init__.py
@@ -57,10 +57,9 @@ from .interfaces import (
Dialect,
ExecutionContext,
ExceptionContext,
-
# backwards compat
Compiled,
- TypeCompiler
+ TypeCompiler,
)
from .base import (
@@ -82,9 +81,7 @@ from .result import (
RowProxy,
)
-from .util import (
- connection_memoize
-)
+from .util import connection_memoize
from . import util, strategies
@@ -92,7 +89,7 @@ from . import util, strategies
# backwards compat
from ..sql import ddl
-default_strategy = 'plain'
+default_strategy = "plain"
def create_engine(*args, **kwargs):
@@ -460,12 +457,12 @@ def create_engine(*args, **kwargs):
"""
- strategy = kwargs.pop('strategy', default_strategy)
+ strategy = kwargs.pop("strategy", default_strategy)
strategy = strategies.strategies[strategy]
return strategy.create(*args, **kwargs)
-def engine_from_config(configuration, prefix='sqlalchemy.', **kwargs):
+def engine_from_config(configuration, prefix="sqlalchemy.", **kwargs):
"""Create a new Engine instance using a configuration dictionary.
The dictionary is typically produced from a config file.
@@ -497,16 +494,15 @@ def engine_from_config(configuration, prefix='sqlalchemy.', **kwargs):
"""
- options = dict((key[len(prefix):], configuration[key])
- for key in configuration
- if key.startswith(prefix))
- options['_coerce_config'] = True
+ options = dict(
+ (key[len(prefix) :], configuration[key])
+ for key in configuration
+ if key.startswith(prefix)
+ )
+ options["_coerce_config"] = True
options.update(kwargs)
- url = options.pop('url')
+ url = options.pop("url")
return create_engine(url, **options)
-__all__ = (
- 'create_engine',
- 'engine_from_config',
-)
+__all__ = ("create_engine", "engine_from_config")
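
A usage sketch for engine_from_config as rewritten above: keys under the prefix are stripped and coerced (via the dialect's engine_config_types), everything else is ignored:

from sqlalchemy import engine_from_config

configuration = {
    "sqlalchemy.url": "sqlite://",   # any URL; sqlite keeps the example self-contained
    "sqlalchemy.echo": "true",       # coerced to a bool
    "other.setting": "ignored",      # does not match the prefix
}
engine = engine_from_config(configuration, prefix="sqlalchemy.")
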
diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py
index 4a057ee59..75d03b744 100644
--- a/lib/sqlalchemy/engine/base.py
+++ b/lib/sqlalchemy/engine/base.py
@@ -61,10 +61,16 @@ class Connection(Connectable):
"""
- def __init__(self, engine, connection=None, close_with_result=False,
- _branch_from=None, _execution_options=None,
- _dispatch=None,
- _has_events=None):
+ def __init__(
+ self,
+ engine,
+ connection=None,
+ close_with_result=False,
+ _branch_from=None,
+ _execution_options=None,
+ _dispatch=None,
+ _has_events=None,
+ ):
"""Construct a new Connection.
        The constructor here is not public and is called only by an
@@ -86,8 +92,11 @@ class Connection(Connectable):
self._has_events = _branch_from._has_events
self.schema_for_object = _branch_from.schema_for_object
else:
- self.__connection = connection \
- if connection is not None else engine.raw_connection()
+ self.__connection = (
+ connection
+ if connection is not None
+ else engine.raw_connection()
+ )
self.__transaction = None
self.__savepoint_seq = 0
self.should_close_with_result = close_with_result
@@ -101,7 +110,8 @@ class Connection(Connectable):
# want to handle any of the engine's events in that case.
self.dispatch = self.dispatch._join(engine.dispatch)
self._has_events = _has_events or (
- _has_events is None and engine._has_events)
+ _has_events is None and engine._has_events
+ )
assert not _execution_options
self._execution_options = engine._execution_options
@@ -134,7 +144,8 @@ class Connection(Connectable):
_branch_from=self,
_execution_options=self._execution_options,
_has_events=self._has_events,
- _dispatch=self.dispatch)
+ _dispatch=self.dispatch,
+ )
@property
def _root(self):
@@ -322,8 +333,10 @@ class Connection(Connectable):
def closed(self):
"""Return True if this connection is closed."""
- return '_Connection__connection' not in self.__dict__ \
+ return (
+ "_Connection__connection" not in self.__dict__
and not self.__can_reconnect
+ )
@property
def invalidated(self):
@@ -425,7 +438,8 @@ class Connection(Connectable):
if self.__transaction is not None:
raise exc.InvalidRequestError(
"Can't reconnect until invalid "
- "transaction is rolled back")
+ "transaction is rolled back"
+ )
self.__connection = self.engine.raw_connection(_connection=self)
self.__invalid = False
return self.__connection
@@ -437,14 +451,15 @@ class Connection(Connectable):
# dialect initializer, where the connection is not wrapped in
# _ConnectionFairy
- return getattr(self.__connection, 'is_valid', False)
+ return getattr(self.__connection, "is_valid", False)
@property
def _still_open_and_connection_is_valid(self):
- return \
- not self.closed and \
- not self.invalidated and \
- getattr(self.__connection, 'is_valid', False)
+ return (
+ not self.closed
+ and not self.invalidated
+ and getattr(self.__connection, "is_valid", False)
+ )
@property
def info(self):
@@ -656,7 +671,8 @@ class Connection(Connectable):
if self.__transaction is not None:
raise exc.InvalidRequestError(
"Cannot start a two phase transaction when a transaction "
- "is already in progress.")
+ "is already in progress."
+ )
if xid is None:
xid = self.engine.dialect.create_xid()
self.__transaction = TwoPhaseTransaction(self, xid)
@@ -705,8 +721,10 @@ class Connection(Connectable):
except BaseException as e:
self._handle_dbapi_exception(e, None, None, None, None)
finally:
- if not self.__invalid and \
- self.connection._reset_agent is self.__transaction:
+ if (
+ not self.__invalid
+ and self.connection._reset_agent is self.__transaction
+ ):
self.connection._reset_agent = None
self.__transaction = None
else:
@@ -725,8 +743,10 @@ class Connection(Connectable):
except BaseException as e:
self._handle_dbapi_exception(e, None, None, None, None)
finally:
- if not self.__invalid and \
- self.connection._reset_agent is self.__transaction:
+ if (
+ not self.__invalid
+ and self.connection._reset_agent is self.__transaction
+ ):
self.connection._reset_agent = None
self.__transaction = None
@@ -738,7 +758,7 @@ class Connection(Connectable):
if name is None:
self.__savepoint_seq += 1
- name = 'sa_savepoint_%s' % self.__savepoint_seq
+ name = "sa_savepoint_%s" % self.__savepoint_seq
if self._still_open_and_connection_is_valid:
self.engine.dialect.do_savepoint(self, name)
return name
@@ -797,7 +817,8 @@ class Connection(Connectable):
assert isinstance(self.__transaction, TwoPhaseTransaction)
try:
self.engine.dialect.do_rollback_twophase(
- self, xid, is_prepared)
+ self, xid, is_prepared
+ )
finally:
if self.connection._reset_agent is self.__transaction:
self.connection._reset_agent = None
@@ -950,16 +971,16 @@ class Connection(Connectable):
def _execute_function(self, func, multiparams, params):
"""Execute a sql.FunctionElement object."""
- return self._execute_clauseelement(func.select(),
- multiparams, params)
+ return self._execute_clauseelement(func.select(), multiparams, params)
def _execute_default(self, default, multiparams, params):
"""Execute a schema.ColumnDefault object."""
if self._has_events or self.engine._has_events:
for fn in self.dispatch.before_execute:
- default, multiparams, params = \
- fn(self, default, multiparams, params)
+ default, multiparams, params = fn(
+ self, default, multiparams, params
+ )
try:
try:
@@ -972,8 +993,7 @@ class Connection(Connectable):
conn = self._revalidate_connection()
dialect = self.dialect
- ctx = dialect.execution_ctx_cls._init_default(
- dialect, self, conn)
+ ctx = dialect.execution_ctx_cls._init_default(dialect, self, conn)
except BaseException as e:
self._handle_dbapi_exception(e, None, None, None, None)
@@ -982,8 +1002,9 @@ class Connection(Connectable):
self.close()
if self._has_events or self.engine._has_events:
- self.dispatch.after_execute(self,
- default, multiparams, params, ret)
+ self.dispatch.after_execute(
+ self, default, multiparams, params, ret
+ )
return ret
@@ -992,25 +1013,25 @@ class Connection(Connectable):
if self._has_events or self.engine._has_events:
for fn in self.dispatch.before_execute:
- ddl, multiparams, params = \
- fn(self, ddl, multiparams, params)
+ ddl, multiparams, params = fn(self, ddl, multiparams, params)
dialect = self.dialect
compiled = ddl.compile(
dialect=dialect,
schema_translate_map=self.schema_for_object
- if not self.schema_for_object.is_default else None)
+ if not self.schema_for_object.is_default
+ else None,
+ )
ret = self._execute_context(
dialect,
dialect.execution_ctx_cls._init_ddl,
compiled,
None,
- compiled
+ compiled,
)
if self._has_events or self.engine._has_events:
- self.dispatch.after_execute(self,
- ddl, multiparams, params, ret)
+ self.dispatch.after_execute(self, ddl, multiparams, params, ret)
return ret
def _execute_clauseelement(self, elem, multiparams, params):
@@ -1018,8 +1039,7 @@ class Connection(Connectable):
if self._has_events or self.engine._has_events:
for fn in self.dispatch.before_execute:
- elem, multiparams, params = \
- fn(self, elem, multiparams, params)
+ elem, multiparams, params = fn(self, elem, multiparams, params)
distilled_params = _distill_params(multiparams, params)
if distilled_params:
@@ -1030,38 +1050,45 @@ class Connection(Connectable):
keys = []
dialect = self.dialect
- if 'compiled_cache' in self._execution_options:
+ if "compiled_cache" in self._execution_options:
key = (
- dialect, elem, tuple(sorted(keys)),
+ dialect,
+ elem,
+ tuple(sorted(keys)),
self.schema_for_object.hash_key,
- len(distilled_params) > 1
+ len(distilled_params) > 1,
)
- compiled_sql = self._execution_options['compiled_cache'].get(key)
+ compiled_sql = self._execution_options["compiled_cache"].get(key)
if compiled_sql is None:
compiled_sql = elem.compile(
- dialect=dialect, column_keys=keys,
+ dialect=dialect,
+ column_keys=keys,
inline=len(distilled_params) > 1,
schema_translate_map=self.schema_for_object
- if not self.schema_for_object.is_default else None
+ if not self.schema_for_object.is_default
+ else None,
)
- self._execution_options['compiled_cache'][key] = compiled_sql
+ self._execution_options["compiled_cache"][key] = compiled_sql
else:
compiled_sql = elem.compile(
- dialect=dialect, column_keys=keys,
+ dialect=dialect,
+ column_keys=keys,
inline=len(distilled_params) > 1,
schema_translate_map=self.schema_for_object
- if not self.schema_for_object.is_default else None)
+ if not self.schema_for_object.is_default
+ else None,
+ )
ret = self._execute_context(
dialect,
dialect.execution_ctx_cls._init_compiled,
compiled_sql,
distilled_params,
- compiled_sql, distilled_params
+ compiled_sql,
+ distilled_params,
)
if self._has_events or self.engine._has_events:
- self.dispatch.after_execute(self,
- elem, multiparams, params, ret)
+ self.dispatch.after_execute(self, elem, multiparams, params, ret)
return ret
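
The compiled_cache branch above only runs when a caller supplies a cache dict through execution options. A self-contained sketch against SQLite (any 1.x engine behaves the same way):

from sqlalchemy import Column, Integer, MetaData, Table, create_engine

engine = create_engine("sqlite://")
t = Table("t", MetaData(), Column("x", Integer))
t.create(engine)

cache = {}
conn = engine.connect().execution_options(compiled_cache=cache)
conn.execute(t.select())  # miss: compiles and stores under the key built above
conn.execute(t.select())  # hit: the cached Compiled is reused
print(len(cache))         # 1
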
def _execute_compiled(self, compiled, multiparams, params):
@@ -1069,8 +1096,9 @@ class Connection(Connectable):
if self._has_events or self.engine._has_events:
for fn in self.dispatch.before_execute:
- compiled, multiparams, params = \
- fn(self, compiled, multiparams, params)
+ compiled, multiparams, params = fn(
+ self, compiled, multiparams, params
+ )
dialect = self.dialect
parameters = _distill_params(multiparams, params)
@@ -1079,11 +1107,13 @@ class Connection(Connectable):
dialect.execution_ctx_cls._init_compiled,
compiled,
parameters,
- compiled, parameters
+ compiled,
+ parameters,
)
if self._has_events or self.engine._has_events:
- self.dispatch.after_execute(self,
- compiled, multiparams, params, ret)
+ self.dispatch.after_execute(
+ self, compiled, multiparams, params, ret
+ )
return ret
def _execute_text(self, statement, multiparams, params):
@@ -1091,8 +1121,9 @@ class Connection(Connectable):
if self._has_events or self.engine._has_events:
for fn in self.dispatch.before_execute:
- statement, multiparams, params = \
- fn(self, statement, multiparams, params)
+ statement, multiparams, params = fn(
+ self, statement, multiparams, params
+ )
dialect = self.dialect
parameters = _distill_params(multiparams, params)
@@ -1101,16 +1132,18 @@ class Connection(Connectable):
dialect.execution_ctx_cls._init_statement,
statement,
parameters,
- statement, parameters
+ statement,
+ parameters,
)
if self._has_events or self.engine._has_events:
- self.dispatch.after_execute(self,
- statement, multiparams, params, ret)
+ self.dispatch.after_execute(
+ self, statement, multiparams, params, ret
+ )
return ret
- def _execute_context(self, dialect, constructor,
- statement, parameters,
- *args):
+ def _execute_context(
+ self, dialect, constructor, statement, parameters, *args
+ ):
"""Create an :class:`.ExecutionContext` and execute, returning
a :class:`.ResultProxy`."""
@@ -1127,31 +1160,36 @@ class Connection(Connectable):
context = constructor(dialect, self, conn, *args)
except BaseException as e:
self._handle_dbapi_exception(
- e,
- util.text_type(statement), parameters,
- None, None)
+ e, util.text_type(statement), parameters, None, None
+ )
if context.compiled:
context.pre_exec()
- cursor, statement, parameters = context.cursor, \
- context.statement, \
- context.parameters
+ cursor, statement, parameters = (
+ context.cursor,
+ context.statement,
+ context.parameters,
+ )
if not context.executemany:
parameters = parameters[0]
if self._has_events or self.engine._has_events:
for fn in self.dispatch.before_cursor_execute:
- statement, parameters = \
- fn(self, cursor, statement, parameters,
- context, context.executemany)
+ statement, parameters = fn(
+ self,
+ cursor,
+ statement,
+ parameters,
+ context,
+ context.executemany,
+ )
if self._echo:
self.engine.logger.info(statement)
self.engine.logger.info(
- "%r",
- sql_util._repr_params(parameters, batches=10)
+ "%r", sql_util._repr_params(parameters, batches=10)
)
evt_handled = False
@@ -1164,10 +1202,8 @@ class Connection(Connectable):
break
if not evt_handled:
self.dialect.do_executemany(
- cursor,
- statement,
- parameters,
- context)
+ cursor, statement, parameters, context
+ )
elif not parameters and context.no_parameters:
if self.dialect._has_events:
for fn in self.dialect.dispatch.do_execute_no_params:
@@ -1176,9 +1212,8 @@ class Connection(Connectable):
break
if not evt_handled:
self.dialect.do_execute_no_params(
- cursor,
- statement,
- context)
+ cursor, statement, context
+ )
else:
if self.dialect._has_events:
for fn in self.dialect.dispatch.do_execute:
@@ -1187,24 +1222,22 @@ class Connection(Connectable):
break
if not evt_handled:
self.dialect.do_execute(
- cursor,
- statement,
- parameters,
- context)
+ cursor, statement, parameters, context
+ )
except BaseException as e:
self._handle_dbapi_exception(
- e,
- statement,
- parameters,
- cursor,
- context)
+ e, statement, parameters, cursor, context
+ )
if self._has_events or self.engine._has_events:
- self.dispatch.after_cursor_execute(self, cursor,
- statement,
- parameters,
- context,
- context.executemany)
+ self.dispatch.after_cursor_execute(
+ self,
+ cursor,
+ statement,
+ parameters,
+ context,
+ context.executemany,
+ )
if context.compiled:
context.post_exec()
@@ -1245,39 +1278,32 @@ class Connection(Connectable):
"""
if self._has_events or self.engine._has_events:
for fn in self.dispatch.before_cursor_execute:
- statement, parameters = \
- fn(self, cursor, statement, parameters,
- context,
- False)
+ statement, parameters = fn(
+ self, cursor, statement, parameters, context, False
+ )
if self._echo:
self.engine.logger.info(statement)
self.engine.logger.info("%r", parameters)
try:
- for fn in () if not self.dialect._has_events \
- else self.dialect.dispatch.do_execute:
+ for fn in (
+ ()
+ if not self.dialect._has_events
+ else self.dialect.dispatch.do_execute
+ ):
if fn(cursor, statement, parameters, context):
break
else:
- self.dialect.do_execute(
- cursor,
- statement,
- parameters,
- context)
+ self.dialect.do_execute(cursor, statement, parameters, context)
except BaseException as e:
self._handle_dbapi_exception(
- e,
- statement,
- parameters,
- cursor,
- context)
+ e, statement, parameters, cursor, context
+ )
if self._has_events or self.engine._has_events:
- self.dispatch.after_cursor_execute(self, cursor,
- statement,
- parameters,
- context,
- False)
+ self.dispatch.after_cursor_execute(
+ self, cursor, statement, parameters, context, False
+ )
def _safe_close_cursor(self, cursor):
"""Close the given cursor, catching exceptions
@@ -1289,17 +1315,15 @@ class Connection(Connectable):
except Exception:
# log the error through the connection pool's logger.
self.engine.pool.logger.error(
- "Error closing cursor", exc_info=True)
+ "Error closing cursor", exc_info=True
+ )
_reentrant_error = False
_is_disconnect = False
- def _handle_dbapi_exception(self,
- e,
- statement,
- parameters,
- cursor,
- context):
+ def _handle_dbapi_exception(
+ self, e, statement, parameters, cursor, context
+ ):
exc_info = sys.exc_info()
if context and context.exception is None:
@@ -1309,15 +1333,14 @@ class Connection(Connectable):
if not self._is_disconnect:
self._is_disconnect = (
- isinstance(e, self.dialect.dbapi.Error) and
- not self.closed and
- self.dialect.is_disconnect(
+ isinstance(e, self.dialect.dbapi.Error)
+ and not self.closed
+ and self.dialect.is_disconnect(
e,
self.__connection if not self.invalidated else None,
- cursor)
- ) or (
- is_exit_exception and not self.closed
- )
+ cursor,
+ )
+ ) or (is_exit_exception and not self.closed)
if context:
context.is_disconnect = self._is_disconnect
@@ -1326,20 +1349,24 @@ class Connection(Connectable):
if self._reentrant_error:
util.raise_from_cause(
- exc.DBAPIError.instance(statement,
- parameters,
- e,
- self.dialect.dbapi.Error,
- dialect=self.dialect),
- exc_info
+ exc.DBAPIError.instance(
+ statement,
+ parameters,
+ e,
+ self.dialect.dbapi.Error,
+ dialect=self.dialect,
+ ),
+ exc_info,
)
self._reentrant_error = True
try:
# non-DBAPI error - if we already got a context,
# or there's no string statement, don't wrap it
- should_wrap = isinstance(e, self.dialect.dbapi.Error) or \
- (statement is not None
- and context is None and not is_exit_exception)
+ should_wrap = isinstance(e, self.dialect.dbapi.Error) or (
+ statement is not None
+ and context is None
+ and not is_exit_exception
+ )
if should_wrap:
sqlalchemy_exception = exc.DBAPIError.instance(
@@ -1348,30 +1375,37 @@ class Connection(Connectable):
e,
self.dialect.dbapi.Error,
connection_invalidated=self._is_disconnect,
- dialect=self.dialect)
+ dialect=self.dialect,
+ )
else:
sqlalchemy_exception = None
newraise = None
- if (self._has_events or self.engine._has_events) and \
- not self._execution_options.get(
- 'skip_user_error_events', False):
+ if (
+ self._has_events or self.engine._has_events
+ ) and not self._execution_options.get(
+ "skip_user_error_events", False
+ ):
# legacy dbapi_error event
if should_wrap and context:
- self.dispatch.dbapi_error(self,
- cursor,
- statement,
- parameters,
- context,
- e)
+ self.dispatch.dbapi_error(
+ self, cursor, statement, parameters, context, e
+ )
# new handle_error event
ctx = ExceptionContextImpl(
- e, sqlalchemy_exception, self.engine,
- self, cursor, statement,
- parameters, context, self._is_disconnect,
- invalidate_pool_on_disconnect)
+ e,
+ sqlalchemy_exception,
+ self.engine,
+ self,
+ cursor,
+ statement,
+ parameters,
+ context,
+ self._is_disconnect,
+ invalidate_pool_on_disconnect,
+ )
for fn in self.dispatch.handle_error:
try:
@@ -1388,13 +1422,15 @@ class Connection(Connectable):
if self._is_disconnect != ctx.is_disconnect:
self._is_disconnect = ctx.is_disconnect
if sqlalchemy_exception:
- sqlalchemy_exception.connection_invalidated = \
+ sqlalchemy_exception.connection_invalidated = (
ctx.is_disconnect
+ )
                    # set up potentially user-defined value for
                    # invalidating the pool.
- invalidate_pool_on_disconnect = \
+ invalidate_pool_on_disconnect = (
ctx.invalidate_pool_on_disconnect
+ )
if should_wrap and context:
context.handle_dbapi_exception(e)
@@ -1408,10 +1444,7 @@ class Connection(Connectable):
if newraise:
util.raise_from_cause(newraise, exc_info)
elif should_wrap:
- util.raise_from_cause(
- sqlalchemy_exception,
- exc_info
- )
+ util.raise_from_cause(sqlalchemy_exception, exc_info)
else:
util.reraise(*exc_info)
@@ -1441,7 +1474,8 @@ class Connection(Connectable):
None,
e,
dialect.dbapi.Error,
- connection_invalidated=is_disconnect)
+ connection_invalidated=is_disconnect,
+ )
else:
sqlalchemy_exception = None
@@ -1449,8 +1483,17 @@ class Connection(Connectable):
if engine._has_events:
ctx = ExceptionContextImpl(
- e, sqlalchemy_exception, engine, None, None, None,
- None, None, is_disconnect, True)
+ e,
+ sqlalchemy_exception,
+ engine,
+ None,
+ None,
+ None,
+ None,
+ None,
+ is_disconnect,
+ True,
+ )
for fn in engine.dispatch.handle_error:
try:
# handler returns an exception;
@@ -1463,18 +1506,15 @@ class Connection(Connectable):
newraise = _raised
break
- if sqlalchemy_exception and \
- is_disconnect != ctx.is_disconnect:
- sqlalchemy_exception.connection_invalidated = \
- is_disconnect = ctx.is_disconnect
+ if sqlalchemy_exception and is_disconnect != ctx.is_disconnect:
+ sqlalchemy_exception.connection_invalidated = (
+ is_disconnect
+ ) = ctx.is_disconnect
if newraise:
util.raise_from_cause(newraise, exc_info)
elif should_wrap:
- util.raise_from_cause(
- sqlalchemy_exception,
- exc_info
- )
+ util.raise_from_cause(sqlalchemy_exception, exc_info)
else:
util.reraise(*exc_info)
@@ -1545,16 +1585,25 @@ class Connection(Connectable):
return callable_(self, *args, **kwargs)
def _run_visitor(self, visitorcallable, element, **kwargs):
- visitorcallable(self.dialect, self,
- **kwargs).traverse_single(element)
+ visitorcallable(self.dialect, self, **kwargs).traverse_single(element)
class ExceptionContextImpl(ExceptionContext):
"""Implement the :class:`.ExceptionContext` interface."""
- def __init__(self, exception, sqlalchemy_exception,
- engine, connection, cursor, statement, parameters,
- context, is_disconnect, invalidate_pool_on_disconnect):
+ def __init__(
+ self,
+ exception,
+ sqlalchemy_exception,
+ engine,
+ connection,
+ cursor,
+ statement,
+ parameters,
+ context,
+ is_disconnect,
+ invalidate_pool_on_disconnect,
+ ):
self.engine = engine
self.connection = connection
self.sqlalchemy_exception = sqlalchemy_exception
@@ -1691,12 +1740,14 @@ class NestedTransaction(Transaction):
def _do_rollback(self):
if self.is_active:
self.connection._rollback_to_savepoint_impl(
- self._savepoint, self._parent)
+ self._savepoint, self._parent
+ )
def _do_commit(self):
if self.is_active:
self.connection._release_savepoint_impl(
- self._savepoint, self._parent)
+ self._savepoint, self._parent
+ )
class TwoPhaseTransaction(Transaction):
@@ -1771,10 +1822,16 @@ class Engine(Connectable, log.Identified):
"""
- def __init__(self, pool, dialect, url,
- logging_name=None, echo=None, proxy=None,
- execution_options=None
- ):
+ def __init__(
+ self,
+ pool,
+ dialect,
+ url,
+ logging_name=None,
+ echo=None,
+ proxy=None,
+ execution_options=None,
+ ):
self.pool = pool
self.url = url
self.dialect = dialect
@@ -1805,8 +1862,7 @@ class Engine(Connectable, log.Identified):
:meth:`.Engine.execution_options`
"""
- self._execution_options = \
- self._execution_options.union(opt)
+ self._execution_options = self._execution_options.union(opt)
self.dispatch.set_engine_execution_options(self, opt)
self.dialect.set_engine_execution_options(self, opt)
@@ -1894,7 +1950,7 @@ class Engine(Connectable, log.Identified):
echo = log.echo_property()
def __repr__(self):
- return 'Engine(%r)' % self.url
+ return "Engine(%r)" % self.url
def dispose(self):
"""Dispose of the connection pool used by this :class:`.Engine`.
@@ -1934,8 +1990,9 @@ class Engine(Connectable, log.Identified):
else:
yield connection
- def _run_visitor(self, visitorcallable, element,
- connection=None, **kwargs):
+ def _run_visitor(
+ self, visitorcallable, element, connection=None, **kwargs
+ ):
with self._optional_conn_ctx_manager(connection) as conn:
conn._run_visitor(visitorcallable, element, **kwargs)
@@ -2122,7 +2179,8 @@ class Engine(Connectable, log.Identified):
self,
self._wrap_pool_connect(self.pool.connect, None),
close_with_result=close_with_result,
- **kwargs)
+ **kwargs
+ )
def table_names(self, schema=None, connection=None):
"""Return a list of all table names available in the database.
@@ -2159,7 +2217,8 @@ class Engine(Connectable, log.Identified):
except dialect.dbapi.Error as e:
if connection is None:
Connection._handle_dbapi_exception_noconnection(
- e, dialect, self)
+ e, dialect, self
+ )
else:
util.reraise(*sys.exc_info())
@@ -2185,7 +2244,8 @@ class Engine(Connectable, log.Identified):
"""
return self._wrap_pool_connect(
- self.pool.unique_connection, _connection)
+ self.pool.unique_connection, _connection
+ )
class OptionEngine(Engine):
@@ -2225,10 +2285,11 @@ class OptionEngine(Engine):
pool = property(_get_pool, _set_pool)
def _get_has_events(self):
- return self._proxied._has_events or \
- self.__dict__.get('_has_events', False)
+ return self._proxied._has_events or self.__dict__.get(
+ "_has_events", False
+ )
def _set_has_events(self, value):
- self.__dict__['_has_events'] = value
+ self.__dict__["_has_events"] = value
_has_events = property(_get_has_events, _set_has_events)
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
index 028abc4c2..d7c2518fe 100644
--- a/lib/sqlalchemy/engine/default.py
+++ b/lib/sqlalchemy/engine/default.py
@@ -24,13 +24,11 @@ import weakref
from .. import event
AUTOCOMMIT_REGEXP = re.compile(
- r'\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER)',
- re.I | re.UNICODE)
+ r"\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER)", re.I | re.UNICODE
+)
# When we're handed literal SQL, ensure it's a SELECT query
-SERVER_SIDE_CURSOR_RE = re.compile(
- r'\s*SELECT',
- re.I | re.UNICODE)
+SERVER_SIDE_CURSOR_RE = re.compile(r"\s*SELECT", re.I | re.UNICODE)
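
Both patterns are simple anchored prefix checks; a standalone illustration (regexes copied from above):

import re

AUTOCOMMIT = re.compile(r"\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER)", re.I | re.UNICODE)
SELECT_RE = re.compile(r"\s*SELECT", re.I | re.UNICODE)

print(bool(AUTOCOMMIT.match("  update t set x = 1")))  # True: triggers autocommit
print(bool(AUTOCOMMIT.match("SELECT 1")))              # False
print(bool(SELECT_RE.match("select * from t")))        # True: server-side cursor ok
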
class DefaultDialect(interfaces.Dialect):
@@ -68,16 +66,18 @@ class DefaultDialect(interfaces.Dialect):
supports_simple_order_by_label = True
- engine_config_types = util.immutabledict([
- ('convert_unicode', util.bool_or_str('force')),
- ('pool_timeout', util.asint),
- ('echo', util.bool_or_str('debug')),
- ('echo_pool', util.bool_or_str('debug')),
- ('pool_recycle', util.asint),
- ('pool_size', util.asint),
- ('max_overflow', util.asint),
- ('pool_threadlocal', util.asbool),
- ])
+ engine_config_types = util.immutabledict(
+ [
+ ("convert_unicode", util.bool_or_str("force")),
+ ("pool_timeout", util.asint),
+ ("echo", util.bool_or_str("debug")),
+ ("echo_pool", util.bool_or_str("debug")),
+ ("pool_recycle", util.asint),
+ ("pool_size", util.asint),
+ ("max_overflow", util.asint),
+ ("pool_threadlocal", util.asbool),
+ ]
+ )
# if the NUMERIC type
# returns decimal.Decimal.
@@ -93,9 +93,9 @@ class DefaultDialect(interfaces.Dialect):
supports_unicode_statements = False
supports_unicode_binds = False
returns_unicode_strings = False
- description_encoding = 'use_encoding'
+ description_encoding = "use_encoding"
- name = 'default'
+ name = "default"
# length at which to truncate
# any identifier.
@@ -111,7 +111,7 @@ class DefaultDialect(interfaces.Dialect):
supports_sane_rowcount = True
supports_sane_multi_rowcount = True
colspecs = {}
- default_paramstyle = 'named'
+ default_paramstyle = "named"
supports_default_values = False
supports_empty_insert = True
supports_multivalues_insert = False
@@ -175,19 +175,26 @@ class DefaultDialect(interfaces.Dialect):
"""
- def __init__(self, convert_unicode=False,
- encoding='utf-8', paramstyle=None, dbapi=None,
- implicit_returning=None,
- supports_right_nested_joins=None,
- case_sensitive=True,
- supports_native_boolean=None,
- empty_in_strategy='static',
- label_length=None, **kwargs):
-
- if not getattr(self, 'ported_sqla_06', True):
+ def __init__(
+ self,
+ convert_unicode=False,
+ encoding="utf-8",
+ paramstyle=None,
+ dbapi=None,
+ implicit_returning=None,
+ supports_right_nested_joins=None,
+ case_sensitive=True,
+ supports_native_boolean=None,
+ empty_in_strategy="static",
+ label_length=None,
+ **kwargs
+ ):
+
+ if not getattr(self, "ported_sqla_06", True):
util.warn(
- "The %s dialect is not yet ported to the 0.6 format" %
- self.name)
+ "The %s dialect is not yet ported to the 0.6 format"
+ % self.name
+ )
self.convert_unicode = convert_unicode
self.encoding = encoding
@@ -202,7 +209,7 @@ class DefaultDialect(interfaces.Dialect):
self.paramstyle = self.default_paramstyle
if implicit_returning is not None:
self.implicit_returning = implicit_returning
- self.positional = self.paramstyle in ('qmark', 'format', 'numeric')
+ self.positional = self.paramstyle in ("qmark", "format", "numeric")
self.identifier_preparer = self.preparer(self)
self.type_compiler = self.type_compiler(self)
if supports_right_nested_joins is not None:
@@ -212,33 +219,33 @@ class DefaultDialect(interfaces.Dialect):
self.case_sensitive = case_sensitive
self.empty_in_strategy = empty_in_strategy
- if empty_in_strategy == 'static':
+ if empty_in_strategy == "static":
self._use_static_in = True
- elif empty_in_strategy in ('dynamic', 'dynamic_warn'):
+ elif empty_in_strategy in ("dynamic", "dynamic_warn"):
self._use_static_in = False
- self._warn_on_empty_in = empty_in_strategy == 'dynamic_warn'
+ self._warn_on_empty_in = empty_in_strategy == "dynamic_warn"
else:
raise exc.ArgumentError(
"empty_in_strategy may be 'static', "
- "'dynamic', or 'dynamic_warn'")
+ "'dynamic', or 'dynamic_warn'"
+ )
if label_length and label_length > self.max_identifier_length:
raise exc.ArgumentError(
"Label length of %d is greater than this dialect's"
- " maximum identifier length of %d" %
- (label_length, self.max_identifier_length))
+ " maximum identifier length of %d"
+ % (label_length, self.max_identifier_length)
+ )
self.label_length = label_length
- if self.description_encoding == 'use_encoding':
- self._description_decoder = \
- processors.to_unicode_processor_factory(
- encoding
- )
+ if self.description_encoding == "use_encoding":
+ self._description_decoder = processors.to_unicode_processor_factory(
+ encoding
+ )
elif self.description_encoding is not None:
- self._description_decoder = \
- processors.to_unicode_processor_factory(
- self.description_encoding
- )
+ self._description_decoder = processors.to_unicode_processor_factory(
+ self.description_encoding
+ )
self._encoder = codecs.getencoder(self.encoding)
self._decoder = processors.to_unicode_processor_factory(self.encoding)
@@ -256,30 +263,35 @@ class DefaultDialect(interfaces.Dialect):
@classmethod
def get_pool_class(cls, url):
- return getattr(cls, 'poolclass', pool.QueuePool)
+ return getattr(cls, "poolclass", pool.QueuePool)
def initialize(self, connection):
try:
- self.server_version_info = \
- self._get_server_version_info(connection)
+ self.server_version_info = self._get_server_version_info(
+ connection
+ )
except NotImplementedError:
self.server_version_info = None
try:
- self.default_schema_name = \
- self._get_default_schema_name(connection)
+ self.default_schema_name = self._get_default_schema_name(
+ connection
+ )
except NotImplementedError:
self.default_schema_name = None
try:
- self.default_isolation_level = \
- self.get_isolation_level(connection.connection)
+ self.default_isolation_level = self.get_isolation_level(
+ connection.connection
+ )
except NotImplementedError:
self.default_isolation_level = None
self.returns_unicode_strings = self._check_unicode_returns(connection)
- if self.description_encoding is not None and \
- self._check_unicode_description(connection):
+ if (
+ self.description_encoding is not None
+ and self._check_unicode_description(connection)
+ ):
self._description_decoder = self.description_encoding = None
self.do_rollback(connection.connection)
@@ -311,7 +323,8 @@ class DefaultDialect(interfaces.Dialect):
def check_unicode(test):
statement = cast_to(
- expression.select([test]).compile(dialect=self))
+ expression.select([test]).compile(dialect=self)
+ )
try:
cursor = connection.connection.cursor()
connection._cursor_execute(cursor, statement, parameters)
@@ -320,8 +333,10 @@ class DefaultDialect(interfaces.Dialect):
except exc.DBAPIError as de:
# note that _cursor_execute() will have closed the cursor
# if an exception is thrown.
- util.warn("Exception attempting to "
- "detect unicode returns: %r" % de)
+ util.warn(
+ "Exception attempting to "
+ "detect unicode returns: %r" % de
+ )
return False
else:
return isinstance(row[0], util.text_type)
@@ -330,13 +345,13 @@ class DefaultDialect(interfaces.Dialect):
# detect plain VARCHAR
expression.cast(
expression.literal_column("'test plain returns'"),
- sqltypes.VARCHAR(60)
+ sqltypes.VARCHAR(60),
),
# detect if there's an NVARCHAR type with different behavior
# available
expression.cast(
expression.literal_column("'test unicode returns'"),
- sqltypes.Unicode(60)
+ sqltypes.Unicode(60),
),
]
@@ -364,9 +379,9 @@ class DefaultDialect(interfaces.Dialect):
try:
cursor.execute(
cast_to(
- expression.select([
- expression.literal_column("'x'").label("some_label")
- ]).compile(dialect=self)
+ expression.select(
+ [expression.literal_column("'x'").label("some_label")]
+ ).compile(dialect=self)
)
)
return isinstance(cursor.description[0][0], util.text_type)
@@ -385,10 +400,12 @@ class DefaultDialect(interfaces.Dialect):
return sqltypes.adapt_type(typeobj, self.colspecs)
def reflecttable(
- self, connection, table, include_columns, exclude_columns, **opts):
+ self, connection, table, include_columns, exclude_columns, **opts
+ ):
insp = reflection.Inspector.from_engine(connection)
return insp.reflecttable(
- table, include_columns, exclude_columns, **opts)
+ table, include_columns, exclude_columns, **opts
+ )
def get_pk_constraint(self, conn, table_name, schema=None, **kw):
"""Compatibility method, adapts the result of get_primary_keys()
@@ -396,16 +413,16 @@ class DefaultDialect(interfaces.Dialect):
"""
return {
- 'constrained_columns':
- self.get_primary_keys(conn, table_name,
- schema=schema, **kw)
+ "constrained_columns": self.get_primary_keys(
+ conn, table_name, schema=schema, **kw
+ )
}
def validate_identifier(self, ident):
if len(ident) > self.max_identifier_length:
raise exc.IdentifierError(
- "Identifier '%s' exceeds maximum length of %d characters" %
- (ident, self.max_identifier_length)
+ "Identifier '%s' exceeds maximum length of %d characters"
+ % (ident, self.max_identifier_length)
)
def connect(self, *cargs, **cparams):
@@ -417,16 +434,16 @@ class DefaultDialect(interfaces.Dialect):
return [[], opts]
def set_engine_execution_options(self, engine, opts):
- if 'isolation_level' in opts:
- isolation_level = opts['isolation_level']
+ if "isolation_level" in opts:
+ isolation_level = opts["isolation_level"]
@event.listens_for(engine, "engine_connect")
def set_isolation(connection, branch):
if not branch:
self._set_connection_isolation(connection, isolation_level)
- if 'schema_translate_map' in opts:
- getter = schema._schema_getter(opts['schema_translate_map'])
+ if "schema_translate_map" in opts:
+ getter = schema._schema_getter(opts["schema_translate_map"])
engine.schema_for_object = getter
@event.listens_for(engine, "engine_connect")
@@ -434,11 +451,11 @@ class DefaultDialect(interfaces.Dialect):
connection.schema_for_object = getter
def set_connection_execution_options(self, connection, opts):
- if 'isolation_level' in opts:
- self._set_connection_isolation(connection, opts['isolation_level'])
+ if "isolation_level" in opts:
+ self._set_connection_isolation(connection, opts["isolation_level"])
- if 'schema_translate_map' in opts:
- getter = schema._schema_getter(opts['schema_translate_map'])
+ if "schema_translate_map" in opts:
+ getter = schema._schema_getter(opts["schema_translate_map"])
connection.schema_for_object = getter
def _set_connection_isolation(self, connection, level):
@@ -447,10 +464,12 @@ class DefaultDialect(interfaces.Dialect):
"Connection is already established with a Transaction; "
"setting isolation_level may implicitly rollback or commit "
"the existing transaction, or have no effect until "
- "next transaction")
+ "next transaction"
+ )
self.set_isolation_level(connection.connection, level)
- connection.connection._connection_record.\
- finalize_callback.append(self.reset_isolation_level)
+ connection.connection._connection_record.finalize_callback.append(
+ self.reset_isolation_level
+ )
def do_begin(self, dbapi_connection):
pass
@@ -593,8 +612,9 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
return self
@classmethod
- def _init_compiled(cls, dialect, connection, dbapi_connection,
- compiled, parameters):
+ def _init_compiled(
+ cls, dialect, connection, dbapi_connection, compiled, parameters
+ ):
"""Initialize execution context for a Compiled construct."""
self = cls.__new__(cls)
@@ -609,16 +629,20 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
assert compiled.can_execute
self.execution_options = compiled.execution_options.union(
- connection._execution_options)
+ connection._execution_options
+ )
self.result_column_struct = (
- compiled._result_columns, compiled._ordered_columns,
- compiled._textual_ordered_columns)
+ compiled._result_columns,
+ compiled._ordered_columns,
+ compiled._textual_ordered_columns,
+ )
self.unicode_statement = util.text_type(compiled)
if not dialect.supports_unicode_statements:
self.statement = self.unicode_statement.encode(
- self.dialect.encoding)
+ self.dialect.encoding
+ )
else:
self.statement = self.unicode_statement
@@ -630,9 +654,10 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
if not parameters:
self.compiled_parameters = [compiled.construct_params()]
else:
- self.compiled_parameters = \
- [compiled.construct_params(m, _group_number=grp) for
- grp, m in enumerate(parameters)]
+ self.compiled_parameters = [
+ compiled.construct_params(m, _group_number=grp)
+ for grp, m in enumerate(parameters)
+ ]
self.executemany = len(parameters) > 1
@@ -642,7 +667,8 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
self.is_crud = True
self._is_explicit_returning = bool(compiled.statement._returning)
self._is_implicit_returning = bool(
- compiled.returning and not compiled.statement._returning)
+ compiled.returning and not compiled.statement._returning
+ )
if self.compiled.insert_prefetch or self.compiled.update_prefetch:
if self.executemany:
@@ -680,7 +706,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
dialect._encoder(key)[0],
processors[key](compiled_params[key])
if key in processors
- else compiled_params[key]
+ else compiled_params[key],
)
for key in compiled_params
)
@@ -690,7 +716,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
key,
processors[key](compiled_params[key])
if key in processors
- else compiled_params[key]
+ else compiled_params[key],
)
for key in compiled_params
)
@@ -708,14 +734,15 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
"""
if self.executemany:
raise exc.InvalidRequestError(
- "'expanding' parameters can't be used with "
- "executemany()")
+ "'expanding' parameters can't be used with " "executemany()"
+ )
if self.compiled.positional and self.compiled._numeric_binds:
# I'm not familiar with any DBAPI that uses 'numeric'
raise NotImplementedError(
"'expanding' bind parameters not supported with "
- "'numeric' paramstyle at this time.")
+ "'numeric' paramstyle at this time."
+ )
self._expanded_parameters = {}
@@ -729,7 +756,8 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
to_update_sets = {}
for name in (
- self.compiled.positiontup if compiled.positional
+ self.compiled.positiontup
+ if compiled.positional
else self.compiled.binds
):
parameter = self.compiled.binds[name]
@@ -748,12 +776,12 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
if not values:
to_update = to_update_sets[name] = []
- replacement_expressions[name] = (
- self.compiled.visit_empty_set_expr(
- parameter._expanding_in_types
- if parameter._expanding_in_types
- else [parameter.type]
- )
+ replacement_expressions[
+ name
+ ] = self.compiled.visit_empty_set_expr(
+ parameter._expanding_in_types
+ if parameter._expanding_in_types
+ else [parameter.type]
)
elif isinstance(values[0], (tuple, list)):
@@ -763,15 +791,17 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
for j, value in enumerate(tuple_element, 1)
]
replacement_expressions[name] = ", ".join(
- "(%s)" % ", ".join(
- self.compiled.bindtemplate % {
- "name":
- to_update[i * len(tuple_element) + j][0]
+ "(%s)"
+ % ", ".join(
+ self.compiled.bindtemplate
+ % {
+ "name": to_update[
+ i * len(tuple_element) + j
+ ][0]
}
for j, value in enumerate(tuple_element)
)
for i, tuple_element in enumerate(values)
-
)
else:
to_update = to_update_sets[name] = [
@@ -779,20 +809,21 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
for i, value in enumerate(values, 1)
]
replacement_expressions[name] = ", ".join(
- self.compiled.bindtemplate % {
- "name": key}
+ self.compiled.bindtemplate % {"name": key}
for key, value in to_update
)
compiled_params.update(to_update)
processors.update(
(key, processors[name])
- for key, value in to_update if name in processors
+ for key, value in to_update
+ if name in processors
)
if compiled.positional:
positiontup.extend(name for name, value in to_update)
self._expanded_parameters[name] = [
- expand_key for expand_key, value in to_update]
+ expand_key for expand_key, value in to_update
+ ]
elif compiled.positional:
positiontup.append(name)
@@ -800,15 +831,14 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
return replacement_expressions[m.group(1)]
self.statement = re.sub(
- r"\[EXPANDING_(\S+)\]",
- process_expanding,
- self.statement
+ r"\[EXPANDING_(\S+)\]", process_expanding, self.statement
)
return positiontup
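(The expansion machinery rewritten above is what backs ``bindparam(..., expanding=True)``: a single placeholder in the statement is replaced at execution time by one bound parameter per element of the passed list. A minimal usage sketch, assuming SQLAlchemy 1.2+ and an in-memory SQLite engine:

import sqlalchemy as sa

engine = sa.create_engine("sqlite://")
with engine.connect() as conn:
    conn.execute(sa.text("CREATE TABLE t (x INTEGER)"))
    conn.execute(sa.text("INSERT INTO t (x) VALUES (1), (2), (3)"))

    # the single :xs placeholder expands to one bound parameter per element
    stmt = sa.text("SELECT x FROM t WHERE x IN :xs").bindparams(
        sa.bindparam("xs", expanding=True)
    )
    print(conn.execute(stmt, xs=[1, 3]).fetchall())  # [(1,), (3,)]
)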
@classmethod
- def _init_statement(cls, dialect, connection, dbapi_connection,
- statement, parameters):
+ def _init_statement(
+ cls, dialect, connection, dbapi_connection, statement, parameters
+ ):
"""Initialize execution context for a string SQL statement."""
self = cls.__new__(cls)
@@ -836,13 +866,15 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
for d in parameters
] or [{}]
else:
- self.parameters = [dialect.execute_sequence_format(p)
- for p in parameters]
+ self.parameters = [
+ dialect.execute_sequence_format(p) for p in parameters
+ ]
self.executemany = len(parameters) > 1
- if not dialect.supports_unicode_statements and \
- isinstance(statement, util.text_type):
+ if not dialect.supports_unicode_statements and isinstance(
+ statement, util.text_type
+ ):
self.unicode_statement = statement
self.statement = dialect._encoder(statement)[0]
else:
@@ -890,11 +922,13 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
@util.memoized_property
def should_autocommit(self):
- autocommit = self.execution_options.get('autocommit',
- not self.compiled and
- self.statement and
- expression.PARSE_AUTOCOMMIT
- or False)
+ autocommit = self.execution_options.get(
+ "autocommit",
+ not self.compiled
+ and self.statement
+ and expression.PARSE_AUTOCOMMIT
+ or False,
+ )
if autocommit is expression.PARSE_AUTOCOMMIT:
return self.should_autocommit_text(self.unicode_statement)
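(The lookup above resolves in two steps: an explicit ``autocommit`` execution option wins outright; otherwise textual statements fall through to ``should_autocommit_text()``, which regex-matches DML keywords. A sketch of both paths, assuming an in-memory SQLite engine:

import sqlalchemy as sa

engine = sa.create_engine("sqlite://")
with engine.connect() as conn:
    # textual CREATE/INSERT match the autocommit regex and commit implicitly
    conn.execute(sa.text("CREATE TABLE t (x INTEGER)"))
    conn.execute(sa.text("INSERT INTO t (x) VALUES (1)"))

    # a SELECT would not match; the option forces autocommit explicitly
    conn.execute(
        sa.text("SELECT x FROM t").execution_options(autocommit=True)
    )
)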
@@ -912,8 +946,10 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
"""
conn = self.root_connection
- if isinstance(stmt, util.text_type) and \
- not self.dialect.supports_unicode_statements:
+ if (
+ isinstance(stmt, util.text_type)
+ and not self.dialect.supports_unicode_statements
+ ):
stmt = self.dialect._encoder(stmt)[0]
if self.dialect.positional:
@@ -926,8 +962,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
if type_ is not None:
# apply type post processors to the result
proc = type_._cached_result_processor(
- self.dialect,
- self.cursor.description[0][1]
+ self.dialect, self.cursor.description[0][1]
)
if proc:
return proc(r)
@@ -945,22 +980,30 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
return False
if self.dialect.server_side_cursors:
- use_server_side = \
- self.execution_options.get('stream_results', True) and (
- (self.compiled and isinstance(self.compiled.statement,
- expression.Selectable)
- or
- (
- (not self.compiled or
- isinstance(self.compiled.statement,
- expression.TextClause))
- and self.statement and SERVER_SIDE_CURSOR_RE.match(
- self.statement))
- )
+ use_server_side = self.execution_options.get(
+ "stream_results", True
+ ) and (
+ (
+ self.compiled
+ and isinstance(
+ self.compiled.statement, expression.Selectable
+ )
+ or (
+ (
+ not self.compiled
+ or isinstance(
+ self.compiled.statement, expression.TextClause
+ )
+ )
+ and self.statement
+ and SERVER_SIDE_CURSOR_RE.match(self.statement)
+ )
)
+ )
else:
- use_server_side = \
- self.execution_options.get('stream_results', False)
+ use_server_side = self.execution_options.get(
+ "stream_results", False
+ )
return use_server_side
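(Summarizing the branch above: with ``server_side_cursors`` set on the dialect, streaming defaults on for SELECT-like statements; otherwise it is opt-in per execution via ``stream_results``. A sketch of the opt-in path, assuming a reachable PostgreSQL database — psycopg2 implements server-side (named) cursors, and the URL here is a placeholder:

import sqlalchemy as sa

engine = sa.create_engine("postgresql+psycopg2://scott:tiger@localhost/test")
with engine.connect() as conn:
    result = conn.execution_options(stream_results=True).execute(
        sa.text("SELECT generate_series(1, 1000000)")
    )
    # rows arrive in batches through a named cursor instead of being
    # buffered entirely on the client
    while True:
        chunk = result.fetchmany(1000)
        if not chunk:
            break
)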
@@ -1039,11 +1082,12 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
return self.dialect.supports_sane_multi_rowcount
def _setup_crud_result_proxy(self):
- if self.isinsert and \
- not self.executemany:
- if not self._is_implicit_returning and \
- not self.compiled.inline and \
- self.dialect.postfetch_lastrowid:
+ if self.isinsert and not self.executemany:
+ if (
+ not self._is_implicit_returning
+ and not self.compiled.inline
+ and self.dialect.postfetch_lastrowid
+ ):
self._setup_ins_pk_from_lastrowid()
@@ -1087,12 +1131,14 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
if autoinc_col is not None:
# apply type post processors to the lastrowid
proc = autoinc_col.type._cached_result_processor(
- self.dialect, None)
+ self.dialect, None
+ )
if proc is not None:
lastrowid = proc(lastrowid)
self.inserted_primary_key = [
- lastrowid if c is autoinc_col else
- compiled_params.get(key_getter(c), None)
+ lastrowid
+ if c is autoinc_col
+ else compiled_params.get(key_getter(c), None)
for c in table.primary_key
]
else:
@@ -1108,8 +1154,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
table = self.compiled.statement.table
compiled_params = self.compiled_parameters[0]
self.inserted_primary_key = [
- compiled_params.get(key_getter(c), None)
- for c in table.primary_key
+ compiled_params.get(key_getter(c), None) for c in table.primary_key
]
def _setup_ins_pk_from_implicit_returning(self, row):
@@ -1129,11 +1174,13 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
]
def lastrow_has_defaults(self):
- return (self.isinsert or self.isupdate) and \
- bool(self.compiled.postfetch)
+ return (self.isinsert or self.isupdate) and bool(
+ self.compiled.postfetch
+ )
def set_input_sizes(
- self, translate=None, include_types=None, exclude_types=None):
+ self, translate=None, include_types=None, exclude_types=None
+ ):
"""Given a cursor and ClauseParameters, call the appropriate
style of ``setinputsizes()`` on the cursor, using DB-API types
from the bind parameter's ``TypeEngine`` objects.
@@ -1143,7 +1190,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
"""
- if not hasattr(self.compiled, 'bind_names'):
+ if not hasattr(self.compiled, "bind_names"):
return
inputsizes = {}
@@ -1153,12 +1200,18 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
dialect_impl_cls = type(dialect_impl)
dbtype = dialect_impl.get_dbapi_type(self.dialect.dbapi)
- if dbtype is not None and (
- not exclude_types or dbtype not in exclude_types and
- dialect_impl_cls not in exclude_types
- ) and (
- not include_types or dbtype in include_types or
- dialect_impl_cls in include_types
+ if (
+ dbtype is not None
+ and (
+ not exclude_types
+ or dbtype not in exclude_types
+ and dialect_impl_cls not in exclude_types
+ )
+ and (
+ not include_types
+ or dbtype in include_types
+ or dialect_impl_cls in include_types
+ )
):
inputsizes[bindparam] = dbtype
else:
@@ -1177,14 +1230,16 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
if dbtype is not None:
if key in self._expanded_parameters:
positional_inputsizes.extend(
- [dbtype] * len(self._expanded_parameters[key]))
+ [dbtype] * len(self._expanded_parameters[key])
+ )
else:
positional_inputsizes.append(dbtype)
try:
self.cursor.setinputsizes(*positional_inputsizes)
except BaseException as e:
self.root_connection._handle_dbapi_exception(
- e, None, None, None, self)
+ e, None, None, None, self
+ )
else:
keyword_inputsizes = {}
for bindparam, key in self.compiled.bind_names.items():
@@ -1199,8 +1254,8 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
key = self.dialect._encoder(key)[0]
if key in self._expanded_parameters:
keyword_inputsizes.update(
- (expand_key, dbtype) for expand_key
- in self._expanded_parameters[key]
+ (expand_key, dbtype)
+ for expand_key in self._expanded_parameters[key]
)
else:
keyword_inputsizes[key] = dbtype
@@ -1208,7 +1263,8 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
self.cursor.setinputsizes(**keyword_inputsizes)
except BaseException as e:
self.root_connection._handle_dbapi_exception(
- e, None, None, None, self)
+ e, None, None, None, self
+ )
def _exec_default(self, column, default, type_):
if default.is_sequence:
@@ -1290,10 +1346,13 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
except AttributeError:
raise exc.InvalidRequestError(
"get_current_parameters() can only be invoked in the "
- "context of a Python side column default function")
- if isolate_multiinsert_groups and \
- self.isinsert and \
- self.compiled.statement._has_multi_parameters:
+ "context of a Python side column default function"
+ )
+ if (
+ isolate_multiinsert_groups
+ and self.isinsert
+ and self.compiled.statement._has_multi_parameters
+ ):
if column._is_multiparam_column:
index = column.index + 1
d = {column.original.key: parameters[column.key]}
@@ -1302,8 +1361,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
index = 0
keys = self.compiled.statement.parameters[0].keys()
d.update(
- (key, parameters["%s_m%d" % (key, index)])
- for key in keys
+ (key, parameters["%s_m%d" % (key, index)]) for key in keys
)
return d
else:
@@ -1360,12 +1418,12 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
def _process_executesingle_defaults(self):
key_getter = self.compiled._key_getters_for_crud_column[2]
- self.current_parameters = compiled_parameters = \
- self.compiled_parameters[0]
+ self.current_parameters = (
+ compiled_parameters
+ ) = self.compiled_parameters[0]
for c in self.compiled.insert_prefetch:
- if c.default and \
- not c.default.is_sequence and c.default.is_scalar:
+ if c.default and not c.default.is_sequence and c.default.is_scalar:
val = c.default.arg
else:
val = self.get_insert_default(c)
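(The ``get_current_parameters()`` changes a few hunks up are part of the Python-side default machinery that ``_process_executesingle_defaults()`` drives: a default callable receives the execution context and may inspect the other parameters of the row being inserted. A minimal sketch, assuming an in-memory SQLite engine:

import sqlalchemy as sa

def default_y(context):
    # context-sensitive default: derive y from the x of the current row
    return context.get_current_parameters()["x"] * 10

metadata = sa.MetaData()
t = sa.Table(
    "t",
    metadata,
    sa.Column("x", sa.Integer),
    sa.Column("y", sa.Integer, default=default_y),
)

engine = sa.create_engine("sqlite://")
metadata.create_all(engine)
with engine.connect() as conn:
    conn.execute(t.insert(), [{"x": 1}, {"x": 2}])
    print(conn.execute(sa.select([t])).fetchall())  # [(1, 10), (2, 20)]
)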
diff --git a/lib/sqlalchemy/engine/interfaces.py b/lib/sqlalchemy/engine/interfaces.py
index 9c3b24e9a..e10e6e884 100644
--- a/lib/sqlalchemy/engine/interfaces.py
+++ b/lib/sqlalchemy/engine/interfaces.py
@@ -198,7 +198,8 @@ class Dialect(object):
pass
def reflecttable(
- self, connection, table, include_columns, exclude_columns):
+ self, connection, table, include_columns, exclude_columns
+ ):
"""Load table description from the database.
Given a :class:`.Connection` and a
@@ -367,7 +368,8 @@ class Dialect(object):
raise NotImplementedError()
def get_unique_constraints(
- self, connection, table_name, schema=None, **kw):
+ self, connection, table_name, schema=None, **kw
+ ):
r"""Return information about unique constraints in `table_name`.
Given a string `table_name` and an optional string `schema`, return
@@ -389,8 +391,7 @@ class Dialect(object):
raise NotImplementedError()
- def get_check_constraints(
- self, connection, table_name, schema=None, **kw):
+ def get_check_constraints(self, connection, table_name, schema=None, **kw):
r"""Return information about check constraints in `table_name`.
Given a string `table_name` and an optional string `schema`, return
@@ -412,8 +413,7 @@ class Dialect(object):
raise NotImplementedError()
- def get_table_comment(
- self, connection, table_name, schema=None, **kw):
+ def get_table_comment(self, connection, table_name, schema=None, **kw):
r"""Return the "comment" for the table identified by `table_name`.
Given a string `table_name` and an optional string `schema`, return
@@ -613,8 +613,9 @@ class Dialect(object):
raise NotImplementedError()
- def do_rollback_twophase(self, connection, xid, is_prepared=True,
- recover=False):
+ def do_rollback_twophase(
+ self, connection, xid, is_prepared=True, recover=False
+ ):
"""Rollback a two phase transaction on the given connection.
:param connection: a :class:`.Connection`.
@@ -627,8 +628,9 @@ class Dialect(object):
raise NotImplementedError()
- def do_commit_twophase(self, connection, xid, is_prepared=True,
- recover=False):
+ def do_commit_twophase(
+ self, connection, xid, is_prepared=True, recover=False
+ ):
"""Commit a two phase transaction on the given connection.
@@ -664,8 +666,9 @@ class Dialect(object):
raise NotImplementedError()
- def do_execute_no_params(self, cursor, statement, parameters,
- context=None):
+ def do_execute_no_params(
+ self, cursor, statement, parameters, context=None
+ ):
"""Provide an implementation of ``cursor.execute(statement)``.
The parameter collection should not be sent.
@@ -899,6 +902,7 @@ class CreateEnginePlugin(object):
.. versionadded:: 1.1
"""
+
def __init__(self, url, kwargs):
"""Contruct a new :class:`.CreateEnginePlugin`.
@@ -1129,20 +1133,24 @@ class Connectable(object):
raise NotImplementedError()
- @util.deprecated("0.7",
- "Use the create() method on the given schema "
- "object directly, i.e. :meth:`.Table.create`, "
- ":meth:`.Index.create`, :meth:`.MetaData.create_all`")
+ @util.deprecated(
+ "0.7",
+ "Use the create() method on the given schema "
+ "object directly, i.e. :meth:`.Table.create`, "
+ ":meth:`.Index.create`, :meth:`.MetaData.create_all`",
+ )
def create(self, entity, **kwargs):
"""Emit CREATE statements for the given schema entity.
"""
raise NotImplementedError()
- @util.deprecated("0.7",
- "Use the drop() method on the given schema "
- "object directly, i.e. :meth:`.Table.drop`, "
- ":meth:`.Index.drop`, :meth:`.MetaData.drop_all`")
+ @util.deprecated(
+ "0.7",
+ "Use the drop() method on the given schema "
+ "object directly, i.e. :meth:`.Table.drop`, "
+ ":meth:`.Index.drop`, :meth:`.MetaData.drop_all`",
+ )
def drop(self, entity, **kwargs):
"""Emit DROP statements for the given schema entity.
"""
@@ -1160,8 +1168,7 @@ class Connectable(object):
"""
raise NotImplementedError()
- def _run_visitor(self, visitorcallable, element,
- **kwargs):
+ def _run_visitor(self, visitorcallable, element, **kwargs):
raise NotImplementedError()
def _execute_clauseelement(self, elem, multiparams=None, params=None):
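(Since the ``CreateEnginePlugin`` constructor shows up in this hunk: a plugin receives the URL and the ``create_engine()`` keyword dict, and can hook the dialect, the pool, and the finished engine. A minimal sketch; the entry-point registration under the ``sqlalchemy.plugins`` group is assumed and not shown:

from sqlalchemy.engine.interfaces import CreateEnginePlugin

class AuditPlugin(CreateEnginePlugin):
    """Logs engine creation; illustrative only."""

    def __init__(self, url, kwargs):
        print("creating engine for %s" % url)

    def handle_dialect_kwargs(self, dialect_cls, dialect_args):
        pass  # inspect or consume dialect keyword arguments here

    def handle_pool_kwargs(self, pool_cls, pool_args):
        pass  # likewise for pool keyword arguments

    def engine_created(self, engine):
        print("engine created: %r" % engine)
)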
diff --git a/lib/sqlalchemy/engine/reflection.py b/lib/sqlalchemy/engine/reflection.py
index 841bb4dfb..9b5fa2459 100644
--- a/lib/sqlalchemy/engine/reflection.py
+++ b/lib/sqlalchemy/engine/reflection.py
@@ -37,17 +37,17 @@ from .base import Connectable
@util.decorator
def cache(fn, self, con, *args, **kw):
- info_cache = kw.get('info_cache', None)
+ info_cache = kw.get("info_cache", None)
if info_cache is None:
return fn(self, con, *args, **kw)
key = (
fn.__name__,
tuple(a for a in args if isinstance(a, util.string_types)),
- tuple((k, v) for k, v in kw.items() if
- isinstance(v,
- util.string_types + util.int_types + (float, )
- )
- )
+ tuple(
+ (k, v)
+ for k, v in kw.items()
+ if isinstance(v, util.string_types + util.int_types + (float,))
+ ),
)
ret = info_cache.get(key)
if ret is None:
@@ -99,7 +99,7 @@ class Inspector(object):
self.bind = bind
# set the engine
- if hasattr(bind, 'engine'):
+ if hasattr(bind, "engine"):
self.engine = bind.engine
else:
self.engine = bind
@@ -130,7 +130,7 @@ class Inspector(object):
See the example at :class:`.Inspector`.
"""
- if hasattr(bind.dialect, 'inspector'):
+ if hasattr(bind.dialect, "inspector"):
return bind.dialect.inspector(bind)
return Inspector(bind)
@@ -153,9 +153,10 @@ class Inspector(object):
"""Return all schema names.
"""
- if hasattr(self.dialect, 'get_schema_names'):
- return self.dialect.get_schema_names(self.bind,
- info_cache=self.info_cache)
+ if hasattr(self.dialect, "get_schema_names"):
+ return self.dialect.get_schema_names(
+ self.bind, info_cache=self.info_cache
+ )
return []
def get_table_names(self, schema=None, order_by=None):
@@ -196,17 +197,18 @@ class Inspector(object):
"""
- if hasattr(self.dialect, 'get_table_names'):
+ if hasattr(self.dialect, "get_table_names"):
tnames = self.dialect.get_table_names(
- self.bind, schema, info_cache=self.info_cache)
+ self.bind, schema, info_cache=self.info_cache
+ )
else:
tnames = self.engine.table_names(schema)
- if order_by == 'foreign_key':
+ if order_by == "foreign_key":
tuples = []
for tname in tnames:
for fkey in self.get_foreign_keys(tname, schema):
- if tname != fkey['referred_table']:
- tuples.append((fkey['referred_table'], tname))
+ if tname != fkey["referred_table"]:
+ tuples.append((fkey["referred_table"], tname))
tnames = list(topological.sort(tuples, tnames))
return tnames
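(As a usage sketch of the sort above, assuming an in-memory SQLite engine: with ``order_by='foreign_key'``, referred tables sort before the tables that depend on them.

import sqlalchemy as sa

engine = sa.create_engine("sqlite://")
with engine.connect() as conn:
    conn.execute(sa.text("CREATE TABLE parent (id INTEGER PRIMARY KEY)"))
    conn.execute(sa.text(
        "CREATE TABLE child ("
        "id INTEGER PRIMARY KEY, "
        "pid INTEGER REFERENCES parent(id))"
    ))

insp = sa.inspect(engine)
print(insp.get_table_names())                        # alphabetical
print(insp.get_table_names(order_by="foreign_key"))  # ['parent', 'child']
)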
@@ -234,9 +236,10 @@ class Inspector(object):
with an already-given :class:`.MetaData`.
"""
- if hasattr(self.dialect, 'get_table_names'):
+ if hasattr(self.dialect, "get_table_names"):
tnames = self.dialect.get_table_names(
- self.bind, schema, info_cache=self.info_cache)
+ self.bind, schema, info_cache=self.info_cache
+ )
else:
tnames = self.engine.table_names(schema)
@@ -246,20 +249,17 @@ class Inspector(object):
fknames_for_table = {}
for tname in tnames:
fkeys = self.get_foreign_keys(tname, schema)
- fknames_for_table[tname] = set(
- [fk['name'] for fk in fkeys]
- )
+ fknames_for_table[tname] = set([fk["name"] for fk in fkeys])
for fkey in fkeys:
- if tname != fkey['referred_table']:
- tuples.add((fkey['referred_table'], tname))
+ if tname != fkey["referred_table"]:
+ tuples.add((fkey["referred_table"], tname))
try:
candidate_sort = list(topological.sort(tuples, tnames))
except exc.CircularDependencyError as err:
for edge in err.edges:
tuples.remove(edge)
remaining_fkcs.update(
- (edge[1], fkc)
- for fkc in fknames_for_table[edge[1]]
+ (edge[1], fkc) for fkc in fknames_for_table[edge[1]]
)
candidate_sort = list(topological.sort(tuples, tnames))
@@ -278,7 +278,8 @@ class Inspector(object):
"""
return self.dialect.get_temp_table_names(
- self.bind, info_cache=self.info_cache)
+ self.bind, info_cache=self.info_cache
+ )
def get_temp_view_names(self):
"""return a list of temporary view names for the current bind.
@@ -290,7 +291,8 @@ class Inspector(object):
"""
return self.dialect.get_temp_view_names(
- self.bind, info_cache=self.info_cache)
+ self.bind, info_cache=self.info_cache
+ )
def get_table_options(self, table_name, schema=None, **kw):
"""Return a dictionary of options specified when the table of the
@@ -306,10 +308,10 @@ class Inspector(object):
use :class:`.quoted_name`.
"""
- if hasattr(self.dialect, 'get_table_options'):
+ if hasattr(self.dialect, "get_table_options"):
return self.dialect.get_table_options(
- self.bind, table_name, schema,
- info_cache=self.info_cache, **kw)
+ self.bind, table_name, schema, info_cache=self.info_cache, **kw
+ )
return {}
def get_view_names(self, schema=None):
@@ -320,8 +322,9 @@ class Inspector(object):
"""
- return self.dialect.get_view_names(self.bind, schema,
- info_cache=self.info_cache)
+ return self.dialect.get_view_names(
+ self.bind, schema, info_cache=self.info_cache
+ )
def get_view_definition(self, view_name, schema=None):
"""Return definition for `view_name`.
@@ -332,7 +335,8 @@ class Inspector(object):
"""
return self.dialect.get_view_definition(
- self.bind, view_name, schema, info_cache=self.info_cache)
+ self.bind, view_name, schema, info_cache=self.info_cache
+ )
def get_columns(self, table_name, schema=None, **kw):
"""Return information about columns in `table_name`.
@@ -364,18 +368,21 @@ class Inspector(object):
"""
- col_defs = self.dialect.get_columns(self.bind, table_name, schema,
- info_cache=self.info_cache,
- **kw)
+ col_defs = self.dialect.get_columns(
+ self.bind, table_name, schema, info_cache=self.info_cache, **kw
+ )
for col_def in col_defs:
# make this easy and only return instances for coltype
- coltype = col_def['type']
+ coltype = col_def["type"]
if not isinstance(coltype, TypeEngine):
- col_def['type'] = coltype()
+ col_def["type"] = coltype()
return col_defs
- @deprecated('0.7', 'Call to deprecated method get_primary_keys.'
- ' Use get_pk_constraint instead.')
+ @deprecated(
+ "0.7",
+ "Call to deprecated method get_primary_keys."
+ " Use get_pk_constraint instead.",
+ )
def get_primary_keys(self, table_name, schema=None, **kw):
"""Return information about primary keys in `table_name`.
@@ -383,9 +390,9 @@ class Inspector(object):
primary key information as a list of column names.
"""
- return self.dialect.get_pk_constraint(self.bind, table_name, schema,
- info_cache=self.info_cache,
- **kw)['constrained_columns']
+ return self.dialect.get_pk_constraint(
+ self.bind, table_name, schema, info_cache=self.info_cache, **kw
+ )["constrained_columns"]
def get_pk_constraint(self, table_name, schema=None, **kw):
"""Return information about primary key constraint on `table_name`.
@@ -407,9 +414,9 @@ class Inspector(object):
use :class:`.quoted_name`.
"""
- return self.dialect.get_pk_constraint(self.bind, table_name, schema,
- info_cache=self.info_cache,
- **kw)
+ return self.dialect.get_pk_constraint(
+ self.bind, table_name, schema, info_cache=self.info_cache, **kw
+ )
def get_foreign_keys(self, table_name, schema=None, **kw):
"""Return information about foreign_keys in `table_name`.
@@ -442,9 +449,9 @@ class Inspector(object):
"""
- return self.dialect.get_foreign_keys(self.bind, table_name, schema,
- info_cache=self.info_cache,
- **kw)
+ return self.dialect.get_foreign_keys(
+ self.bind, table_name, schema, info_cache=self.info_cache, **kw
+ )
def get_indexes(self, table_name, schema=None, **kw):
"""Return information about indexes in `table_name`.
@@ -476,9 +483,9 @@ class Inspector(object):
"""
- return self.dialect.get_indexes(self.bind, table_name,
- schema,
- info_cache=self.info_cache, **kw)
+ return self.dialect.get_indexes(
+ self.bind, table_name, schema, info_cache=self.info_cache, **kw
+ )
def get_unique_constraints(self, table_name, schema=None, **kw):
"""Return information about unique constraints in `table_name`.
@@ -504,7 +511,8 @@ class Inspector(object):
"""
return self.dialect.get_unique_constraints(
- self.bind, table_name, schema, info_cache=self.info_cache, **kw)
+ self.bind, table_name, schema, info_cache=self.info_cache, **kw
+ )
def get_table_comment(self, table_name, schema=None, **kw):
"""Return information about the table comment for ``table_name``.
@@ -523,8 +531,8 @@ class Inspector(object):
"""
return self.dialect.get_table_comment(
- self.bind, table_name, schema, info_cache=self.info_cache,
- **kw)
+ self.bind, table_name, schema, info_cache=self.info_cache, **kw
+ )
def get_check_constraints(self, table_name, schema=None, **kw):
"""Return information about check constraints in `table_name`.
@@ -550,10 +558,12 @@ class Inspector(object):
"""
return self.dialect.get_check_constraints(
- self.bind, table_name, schema, info_cache=self.info_cache, **kw)
+ self.bind, table_name, schema, info_cache=self.info_cache, **kw
+ )
- def reflecttable(self, table, include_columns, exclude_columns=(),
- _extend_on=None):
+ def reflecttable(
+ self, table, include_columns, exclude_columns=(), _extend_on=None
+ ):
"""Given a Table object, load its internal constructs based on
introspection.
@@ -599,7 +609,8 @@ class Inspector(object):
# reflect table options, like mysql_engine
tbl_opts = self.get_table_options(
- table_name, schema, **table.dialect_kwargs)
+ table_name, schema, **table.dialect_kwargs
+ )
if tbl_opts:
# add additional kwargs to the Table if the dialect
# returned them
@@ -615,185 +626,251 @@ class Inspector(object):
cols_by_orig_name = {}
for col_d in self.get_columns(
- table_name, schema, **table.dialect_kwargs):
+ table_name, schema, **table.dialect_kwargs
+ ):
found_table = True
self._reflect_column(
- table, col_d, include_columns,
- exclude_columns, cols_by_orig_name)
+ table,
+ col_d,
+ include_columns,
+ exclude_columns,
+ cols_by_orig_name,
+ )
if not found_table:
raise exc.NoSuchTableError(table.name)
self._reflect_pk(
- table_name, schema, table, cols_by_orig_name, exclude_columns)
+ table_name, schema, table, cols_by_orig_name, exclude_columns
+ )
self._reflect_fk(
- table_name, schema, table, cols_by_orig_name,
- exclude_columns, _extend_on, reflection_options)
+ table_name,
+ schema,
+ table,
+ cols_by_orig_name,
+ exclude_columns,
+ _extend_on,
+ reflection_options,
+ )
self._reflect_indexes(
- table_name, schema, table, cols_by_orig_name,
- include_columns, exclude_columns, reflection_options)
+ table_name,
+ schema,
+ table,
+ cols_by_orig_name,
+ include_columns,
+ exclude_columns,
+ reflection_options,
+ )
self._reflect_unique_constraints(
- table_name, schema, table, cols_by_orig_name,
- include_columns, exclude_columns, reflection_options)
+ table_name,
+ schema,
+ table,
+ cols_by_orig_name,
+ include_columns,
+ exclude_columns,
+ reflection_options,
+ )
self._reflect_check_constraints(
- table_name, schema, table, cols_by_orig_name,
- include_columns, exclude_columns, reflection_options)
+ table_name,
+ schema,
+ table,
+ cols_by_orig_name,
+ include_columns,
+ exclude_columns,
+ reflection_options,
+ )
self._reflect_table_comment(
table_name, schema, table, reflection_options
)
def _reflect_column(
- self, table, col_d, include_columns,
- exclude_columns, cols_by_orig_name):
+ self, table, col_d, include_columns, exclude_columns, cols_by_orig_name
+ ):
- orig_name = col_d['name']
+ orig_name = col_d["name"]
table.dispatch.column_reflect(self, table, col_d)
# fetch name again as column_reflect is allowed to
# change it
- name = col_d['name']
- if (include_columns and name not in include_columns) \
- or (exclude_columns and name in exclude_columns):
+ name = col_d["name"]
+ if (include_columns and name not in include_columns) or (
+ exclude_columns and name in exclude_columns
+ ):
return
- coltype = col_d['type']
+ coltype = col_d["type"]
col_kw = dict(
(k, col_d[k])
for k in [
- 'nullable', 'autoincrement', 'quote', 'info', 'key',
- 'comment']
+ "nullable",
+ "autoincrement",
+ "quote",
+ "info",
+ "key",
+ "comment",
+ ]
if k in col_d
)
- if 'dialect_options' in col_d:
- col_kw.update(col_d['dialect_options'])
+ if "dialect_options" in col_d:
+ col_kw.update(col_d["dialect_options"])
colargs = []
- if col_d.get('default') is not None:
- default = col_d['default']
+ if col_d.get("default") is not None:
+ default = col_d["default"]
if isinstance(default, sql.elements.TextClause):
default = sa_schema.DefaultClause(default, _reflected=True)
elif not isinstance(default, sa_schema.FetchedValue):
default = sa_schema.DefaultClause(
- sql.text(col_d['default']), _reflected=True)
+ sql.text(col_d["default"]), _reflected=True
+ )
colargs.append(default)
- if 'sequence' in col_d:
+ if "sequence" in col_d:
self._reflect_col_sequence(col_d, colargs)
- cols_by_orig_name[orig_name] = col = \
- sa_schema.Column(name, coltype, *colargs, **col_kw)
+ cols_by_orig_name[orig_name] = col = sa_schema.Column(
+ name, coltype, *colargs, **col_kw
+ )
if col.key in table.primary_key:
col.primary_key = True
table.append_column(col)
def _reflect_col_sequence(self, col_d, colargs):
- if 'sequence' in col_d:
+ if "sequence" in col_d:
# TODO: mssql and sybase use this.
- seq = col_d['sequence']
- sequence = sa_schema.Sequence(seq['name'], 1, 1)
- if 'start' in seq:
- sequence.start = seq['start']
- if 'increment' in seq:
- sequence.increment = seq['increment']
+ seq = col_d["sequence"]
+ sequence = sa_schema.Sequence(seq["name"], 1, 1)
+ if "start" in seq:
+ sequence.start = seq["start"]
+ if "increment" in seq:
+ sequence.increment = seq["increment"]
colargs.append(sequence)
def _reflect_pk(
- self, table_name, schema, table,
- cols_by_orig_name, exclude_columns):
+ self, table_name, schema, table, cols_by_orig_name, exclude_columns
+ ):
pk_cons = self.get_pk_constraint(
- table_name, schema, **table.dialect_kwargs)
+ table_name, schema, **table.dialect_kwargs
+ )
if pk_cons:
pk_cols = [
cols_by_orig_name[pk]
- for pk in pk_cons['constrained_columns']
+ for pk in pk_cons["constrained_columns"]
if pk in cols_by_orig_name and pk not in exclude_columns
]
# update pk constraint name
- table.primary_key.name = pk_cons.get('name')
+ table.primary_key.name = pk_cons.get("name")
# tell the PKConstraint to re-initialize
# its column collection
table.primary_key._reload(pk_cols)
def _reflect_fk(
- self, table_name, schema, table, cols_by_orig_name,
- exclude_columns, _extend_on, reflection_options):
+ self,
+ table_name,
+ schema,
+ table,
+ cols_by_orig_name,
+ exclude_columns,
+ _extend_on,
+ reflection_options,
+ ):
fkeys = self.get_foreign_keys(
- table_name, schema, **table.dialect_kwargs)
+ table_name, schema, **table.dialect_kwargs
+ )
for fkey_d in fkeys:
- conname = fkey_d['name']
+ conname = fkey_d["name"]
# look for columns by orig name in cols_by_orig_name,
# but support columns that are in-Python only as fallback
constrained_columns = [
- cols_by_orig_name[c].key
- if c in cols_by_orig_name else c
- for c in fkey_d['constrained_columns']
+ cols_by_orig_name[c].key if c in cols_by_orig_name else c
+ for c in fkey_d["constrained_columns"]
]
if exclude_columns and set(constrained_columns).intersection(
- exclude_columns):
+ exclude_columns
+ ):
continue
- referred_schema = fkey_d['referred_schema']
- referred_table = fkey_d['referred_table']
- referred_columns = fkey_d['referred_columns']
+ referred_schema = fkey_d["referred_schema"]
+ referred_table = fkey_d["referred_table"]
+ referred_columns = fkey_d["referred_columns"]
refspec = []
if referred_schema is not None:
- sa_schema.Table(referred_table, table.metadata,
- autoload=True, schema=referred_schema,
- autoload_with=self.bind,
- _extend_on=_extend_on,
- **reflection_options
- )
+ sa_schema.Table(
+ referred_table,
+ table.metadata,
+ autoload=True,
+ schema=referred_schema,
+ autoload_with=self.bind,
+ _extend_on=_extend_on,
+ **reflection_options
+ )
for column in referred_columns:
- refspec.append(".".join(
- [referred_schema, referred_table, column]))
+ refspec.append(
+ ".".join([referred_schema, referred_table, column])
+ )
else:
- sa_schema.Table(referred_table, table.metadata, autoload=True,
- autoload_with=self.bind,
- schema=sa_schema.BLANK_SCHEMA,
- _extend_on=_extend_on,
- **reflection_options
- )
+ sa_schema.Table(
+ referred_table,
+ table.metadata,
+ autoload=True,
+ autoload_with=self.bind,
+ schema=sa_schema.BLANK_SCHEMA,
+ _extend_on=_extend_on,
+ **reflection_options
+ )
for column in referred_columns:
refspec.append(".".join([referred_table, column]))
- if 'options' in fkey_d:
- options = fkey_d['options']
+ if "options" in fkey_d:
+ options = fkey_d["options"]
else:
options = {}
table.append_constraint(
- sa_schema.ForeignKeyConstraint(constrained_columns, refspec,
- conname, link_to_name=True,
- **options))
+ sa_schema.ForeignKeyConstraint(
+ constrained_columns,
+ refspec,
+ conname,
+ link_to_name=True,
+ **options
+ )
+ )
def _reflect_indexes(
- self, table_name, schema, table, cols_by_orig_name,
- include_columns, exclude_columns, reflection_options):
+ self,
+ table_name,
+ schema,
+ table,
+ cols_by_orig_name,
+ include_columns,
+ exclude_columns,
+ reflection_options,
+ ):
# Indexes
indexes = self.get_indexes(table_name, schema)
for index_d in indexes:
- name = index_d['name']
- columns = index_d['column_names']
- unique = index_d['unique']
- flavor = index_d.get('type', 'index')
- dialect_options = index_d.get('dialect_options', {})
-
- duplicates = index_d.get('duplicates_constraint')
- if include_columns and \
- not set(columns).issubset(include_columns):
+ name = index_d["name"]
+ columns = index_d["column_names"]
+ unique = index_d["unique"]
+ flavor = index_d.get("type", "index")
+ dialect_options = index_d.get("dialect_options", {})
+
+ duplicates = index_d.get("duplicates_constraint")
+ if include_columns and not set(columns).issubset(include_columns):
util.warn(
- "Omitting %s key for (%s), key covers omitted columns." %
- (flavor, ', '.join(columns)))
+ "Omitting %s key for (%s), key covers omitted columns."
+ % (flavor, ", ".join(columns))
+ )
continue
if duplicates:
continue
@@ -802,26 +879,36 @@ class Inspector(object):
idx_cols = []
for c in columns:
try:
- idx_col = cols_by_orig_name[c] \
- if c in cols_by_orig_name else table.c[c]
+ idx_col = (
+ cols_by_orig_name[c]
+ if c in cols_by_orig_name
+ else table.c[c]
+ )
except KeyError:
util.warn(
"%s key '%s' was not located in "
- "columns for table '%s'" % (
- flavor, c, table_name
- ))
+ "columns for table '%s'" % (flavor, c, table_name)
+ )
else:
idx_cols.append(idx_col)
sa_schema.Index(
- name, *idx_cols,
+ name,
+ *idx_cols,
_table=table,
- **dict(list(dialect_options.items()) + [('unique', unique)])
+ **dict(list(dialect_options.items()) + [("unique", unique)])
)
def _reflect_unique_constraints(
- self, table_name, schema, table, cols_by_orig_name,
- include_columns, exclude_columns, reflection_options):
+ self,
+ table_name,
+ schema,
+ table,
+ cols_by_orig_name,
+ include_columns,
+ exclude_columns,
+ reflection_options,
+ ):
# Unique Constraints
try:
@@ -831,15 +918,14 @@ class Inspector(object):
return
for const_d in constraints:
- conname = const_d['name']
- columns = const_d['column_names']
- duplicates = const_d.get('duplicates_index')
- if include_columns and \
- not set(columns).issubset(include_columns):
+ conname = const_d["name"]
+ columns = const_d["column_names"]
+ duplicates = const_d.get("duplicates_index")
+ if include_columns and not set(columns).issubset(include_columns):
util.warn(
"Omitting unique constraint key for (%s), "
- "key covers omitted columns." %
- ', '.join(columns))
+ "key covers omitted columns." % ", ".join(columns)
+ )
continue
if duplicates:
continue
@@ -848,20 +934,32 @@ class Inspector(object):
constrained_cols = []
for c in columns:
try:
- constrained_col = cols_by_orig_name[c] \
- if c in cols_by_orig_name else table.c[c]
+ constrained_col = (
+ cols_by_orig_name[c]
+ if c in cols_by_orig_name
+ else table.c[c]
+ )
except KeyError:
util.warn(
"unique constraint key '%s' was not located in "
- "columns for table '%s'" % (c, table_name))
+ "columns for table '%s'" % (c, table_name)
+ )
else:
constrained_cols.append(constrained_col)
table.append_constraint(
- sa_schema.UniqueConstraint(*constrained_cols, name=conname))
+ sa_schema.UniqueConstraint(*constrained_cols, name=conname)
+ )
def _reflect_check_constraints(
- self, table_name, schema, table, cols_by_orig_name,
- include_columns, exclude_columns, reflection_options):
+ self,
+ table_name,
+ schema,
+ table,
+ cols_by_orig_name,
+ include_columns,
+ exclude_columns,
+ reflection_options,
+ ):
try:
constraints = self.get_check_constraints(table_name, schema)
except NotImplementedError:
@@ -869,14 +967,14 @@ class Inspector(object):
return
for const_d in constraints:
- table.append_constraint(
- sa_schema.CheckConstraint(**const_d))
+ table.append_constraint(sa_schema.CheckConstraint(**const_d))
def _reflect_table_comment(
- self, table_name, schema, table, reflection_options):
+ self, table_name, schema, table, reflection_options
+ ):
try:
comment_dict = self.get_table_comment(table_name, schema)
except NotImplementedError:
return
else:
- table.comment = comment_dict.get('text', None)
+ table.comment = comment_dict.get("text", None)
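(All of the ``_reflect_*`` steps above are driven by ``Inspector.reflecttable()``, which is what ``Table`` autoloading calls into. A minimal sketch, assuming an in-memory SQLite engine:

import sqlalchemy as sa

engine = sa.create_engine("sqlite://")
with engine.connect() as conn:
    conn.execute(sa.text(
        "CREATE TABLE users (id INTEGER PRIMARY KEY, name VARCHAR(50))"
    ))

# columns, primary key, foreign keys, indexes, constraints and the
# table comment are filled in by the reflection steps shown above
metadata = sa.MetaData()
users = sa.Table("users", metadata, autoload=True, autoload_with=engine)
print([c.name for c in users.columns])  # ['id', 'name']
)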
diff --git a/lib/sqlalchemy/engine/result.py b/lib/sqlalchemy/engine/result.py
index d4c862375..5ad0d2909 100644
--- a/lib/sqlalchemy/engine/result.py
+++ b/lib/sqlalchemy/engine/result.py
@@ -27,20 +27,25 @@ try:
# the extension is present.
def rowproxy_reconstructor(cls, state):
return safe_rowproxy_reconstructor(cls, state)
+
+
except ImportError:
+
def rowproxy_reconstructor(cls, state):
obj = cls.__new__(cls)
obj.__setstate__(state)
return obj
+
try:
from sqlalchemy.cresultproxy import BaseRowProxy
+
_baserowproxy_usecext = True
except ImportError:
_baserowproxy_usecext = False
class BaseRowProxy(object):
- __slots__ = ('_parent', '_row', '_processors', '_keymap')
+ __slots__ = ("_parent", "_row", "_processors", "_keymap")
def __init__(self, parent, row, processors, keymap):
"""RowProxy objects are constructed by ResultProxy objects."""
@@ -51,8 +56,10 @@ except ImportError:
self._keymap = keymap
def __reduce__(self):
- return (rowproxy_reconstructor,
- (self.__class__, self.__getstate__()))
+ return (
+ rowproxy_reconstructor,
+ (self.__class__, self.__getstate__()),
+ )
def values(self):
"""Return the values represented by this RowProxy as a list."""
@@ -76,8 +83,9 @@ except ImportError:
except TypeError:
if isinstance(key, slice):
l = []
- for processor, value in zip(self._processors[key],
- self._row[key]):
+ for processor, value in zip(
+ self._processors[key], self._row[key]
+ ):
if processor is None:
l.append(value)
else:
@@ -88,7 +96,8 @@ except ImportError:
if index is None:
raise exc.InvalidRequestError(
"Ambiguous column name '%s' in "
- "result set column descriptions" % obj)
+ "result set column descriptions" % obj
+ )
if processor is not None:
return processor(self._row[index])
else:
@@ -110,29 +119,29 @@ class RowProxy(BaseRowProxy):
mapped to the original Columns that produced this result set (for
results that correspond to constructed SQL expressions).
"""
+
__slots__ = ()
def __contains__(self, key):
return self._parent._has_key(key)
def __getstate__(self):
- return {
- '_parent': self._parent,
- '_row': tuple(self)
- }
+ return {"_parent": self._parent, "_row": tuple(self)}
def __setstate__(self, state):
- self._parent = parent = state['_parent']
- self._row = state['_row']
+ self._parent = parent = state["_parent"]
+ self._row = state["_row"]
self._processors = parent._processors
self._keymap = parent._keymap
__hash__ = None
def _op(self, other, op):
- return op(tuple(self), tuple(other)) \
- if isinstance(other, RowProxy) \
+ return (
+ op(tuple(self), tuple(other))
+ if isinstance(other, RowProxy)
else op(tuple(self), other)
+ )
def __lt__(self, other):
return self._op(other, operator.lt)
@@ -176,6 +185,7 @@ class RowProxy(BaseRowProxy):
def itervalues(self):
return iter(self)
+
try:
# Register RowProxy with Sequence,
# so sequence protocol is implemented
@@ -189,8 +199,13 @@ class ResultMetaData(object):
context."""
__slots__ = (
- '_keymap', 'case_sensitive', 'matched_on_name',
- '_processors', 'keys', '_orig_processors')
+ "_keymap",
+ "case_sensitive",
+ "matched_on_name",
+ "_processors",
+ "keys",
+ "_orig_processors",
+ )
def __init__(self, parent, cursor_description):
context = parent.context
@@ -200,18 +215,25 @@ class ResultMetaData(object):
self._orig_processors = None
if context.result_column_struct:
- result_columns, cols_are_ordered, textual_ordered = \
+ result_columns, cols_are_ordered, textual_ordered = (
context.result_column_struct
+ )
num_ctx_cols = len(result_columns)
else:
- result_columns = cols_are_ordered = \
- num_ctx_cols = textual_ordered = False
+ result_columns = (
+ cols_are_ordered
+ ) = num_ctx_cols = textual_ordered = False
# merge cursor.description with the column info
# present in the compiled structure, if any
raw = self._merge_cursor_description(
- context, cursor_description, result_columns,
- num_ctx_cols, cols_are_ordered, textual_ordered)
+ context,
+ cursor_description,
+ result_columns,
+ num_ctx_cols,
+ cols_are_ordered,
+ textual_ordered,
+ )
self._keymap = {}
if not _baserowproxy_usecext:
@@ -223,23 +245,20 @@ class ResultMetaData(object):
len_raw = len(raw)
- self._keymap.update([
- (elem[0], (elem[3], elem[4], elem[0]))
- for elem in raw
- ] + [
- (elem[0] - len_raw, (elem[3], elem[4], elem[0]))
- for elem in raw
- ])
+ self._keymap.update(
+ [(elem[0], (elem[3], elem[4], elem[0])) for elem in raw]
+ + [
+ (elem[0] - len_raw, (elem[3], elem[4], elem[0]))
+ for elem in raw
+ ]
+ )
# processors in key order for certain per-row
# views like __iter__ and slices
self._processors = [elem[3] for elem in raw]
# keymap by primary string...
- by_key = dict([
- (elem[2], (elem[3], elem[4], elem[0]))
- for elem in raw
- ])
+ by_key = dict([(elem[2], (elem[3], elem[4], elem[0])) for elem in raw])
# for compiled SQL constructs, copy additional lookup keys into
# the key lookup map, such as Column objects, labels,
@@ -264,29 +283,38 @@ class ResultMetaData(object):
# copy secondary elements from compiled columns
# into self._keymap, write in the potentially "ambiguous"
# element
- self._keymap.update([
- (obj_elem, by_key[elem[2]])
- for elem in raw if elem[4]
- for obj_elem in elem[4]
- ])
+ self._keymap.update(
+ [
+ (obj_elem, by_key[elem[2]])
+ for elem in raw
+ if elem[4]
+ for obj_elem in elem[4]
+ ]
+ )
# if we did a pure positional match, then reset the
# original "expression element" back to the "unambiguous"
# entry. This is a new behavior in 1.1 which impacts
# TextAsFrom but also straight compiled SQL constructs.
if not self.matched_on_name:
- self._keymap.update([
- (elem[4][0], (elem[3], elem[4], elem[0]))
- for elem in raw if elem[4]
- ])
+ self._keymap.update(
+ [
+ (elem[4][0], (elem[3], elem[4], elem[0]))
+ for elem in raw
+ if elem[4]
+ ]
+ )
else:
# no dupes - copy secondary elements from compiled
# columns into self._keymap
- self._keymap.update([
- (obj_elem, (elem[3], elem[4], elem[0]))
- for elem in raw if elem[4]
- for obj_elem in elem[4]
- ])
+ self._keymap.update(
+ [
+ (obj_elem, (elem[3], elem[4], elem[0]))
+ for elem in raw
+ if elem[4]
+ for obj_elem in elem[4]
+ ]
+ )
# update keymap with primary string names taking
# precedence
@@ -294,14 +322,19 @@ class ResultMetaData(object):
# update keymap with "translated" names (sqlite-only thing)
if not num_ctx_cols and context._translate_colname:
- self._keymap.update([
- (elem[5], self._keymap[elem[2]])
- for elem in raw if elem[5]
- ])
+ self._keymap.update(
+ [(elem[5], self._keymap[elem[2]]) for elem in raw if elem[5]]
+ )
def _merge_cursor_description(
- self, context, cursor_description, result_columns,
- num_ctx_cols, cols_are_ordered, textual_ordered):
+ self,
+ context,
+ cursor_description,
+ result_columns,
+ num_ctx_cols,
+ cols_are_ordered,
+ textual_ordered,
+ ):
"""Merge a cursor.description with compiled result column information.
There are at least four separate strategies used here, selected
@@ -357,10 +390,12 @@ class ResultMetaData(object):
case_sensitive = context.dialect.case_sensitive
- if num_ctx_cols and \
- cols_are_ordered and \
- not textual_ordered and \
- num_ctx_cols == len(cursor_description):
+ if (
+ num_ctx_cols
+ and cols_are_ordered
+ and not textual_ordered
+ and num_ctx_cols == len(cursor_description)
+ ):
self.keys = [elem[0] for elem in result_columns]
# pure positional 1-1 case; doesn't need to read
# the names from cursor.description
@@ -373,9 +408,9 @@ class ResultMetaData(object):
type_, key, cursor_description[idx][1]
),
obj,
- None
- ) for idx, (key, name, obj, type_)
- in enumerate(result_columns)
+ None,
+ )
+ for idx, (key, name, obj, type_) in enumerate(result_columns)
]
else:
# name-based or text-positional cases, where we need
@@ -383,26 +418,32 @@ class ResultMetaData(object):
if textual_ordered:
# textual positional case
raw_iterator = self._merge_textual_cols_by_position(
- context, cursor_description, result_columns)
+ context, cursor_description, result_columns
+ )
elif num_ctx_cols:
# compiled SQL with a mismatch of description cols
# vs. compiled cols, or textual w/ unordered columns
raw_iterator = self._merge_cols_by_name(
- context, cursor_description, result_columns)
+ context, cursor_description, result_columns
+ )
else:
# no compiled SQL, just a raw string
raw_iterator = self._merge_cols_by_none(
- context, cursor_description)
+ context, cursor_description
+ )
return [
(
- idx, colname, colname,
+ idx,
+ colname,
+ colname,
context.get_result_processor(
- mapped_type, colname, coltype),
- obj, untranslated)
-
- for idx, colname, mapped_type, coltype, obj, untranslated
- in raw_iterator
+ mapped_type, colname, coltype
+ ),
+ obj,
+ untranslated,
+ )
+ for idx, colname, mapped_type, coltype, obj, untranslated in raw_iterator
]
def _colnames_from_description(self, context, cursor_description):
@@ -416,10 +457,14 @@ class ResultMetaData(object):
dialect = context.dialect
case_sensitive = dialect.case_sensitive
translate_colname = context._translate_colname
- description_decoder = dialect._description_decoder \
- if dialect.description_encoding else None
- normalize_name = dialect.normalize_name \
- if dialect.requires_name_normalize else None
+ description_decoder = (
+ dialect._description_decoder
+ if dialect.description_encoding
+ else None
+ )
+ normalize_name = (
+ dialect.normalize_name if dialect.requires_name_normalize else None
+ )
untranslated = None
self.keys = []
@@ -444,20 +489,25 @@ class ResultMetaData(object):
yield idx, colname, untranslated, coltype
def _merge_textual_cols_by_position(
- self, context, cursor_description, result_columns):
+ self, context, cursor_description, result_columns
+ ):
dialect = context.dialect
num_ctx_cols = len(result_columns) if result_columns else None
if num_ctx_cols > len(cursor_description):
util.warn(
"Number of columns in textual SQL (%d) is "
- "smaller than number of columns requested (%d)" % (
- num_ctx_cols, len(cursor_description)
- ))
+ "smaller than number of columns requested (%d)"
+ % (num_ctx_cols, len(cursor_description))
+ )
seen = set()
- for idx, colname, untranslated, coltype in \
- self._colnames_from_description(context, cursor_description):
+ for (
+ idx,
+ colname,
+ untranslated,
+ coltype,
+ ) in self._colnames_from_description(context, cursor_description):
if idx < num_ctx_cols:
ctx_rec = result_columns[idx]
obj = ctx_rec[2]
@@ -465,7 +515,8 @@ class ResultMetaData(object):
if obj[0] in seen:
raise exc.InvalidRequestError(
"Duplicate column expression requested "
- "in textual SQL: %r" % obj[0])
+ "in textual SQL: %r" % obj[0]
+ )
seen.add(obj[0])
else:
mapped_type = sqltypes.NULLTYPE
@@ -479,8 +530,12 @@ class ResultMetaData(object):
result_map = self._create_result_map(result_columns, case_sensitive)
self.matched_on_name = True
- for idx, colname, untranslated, coltype in \
- self._colnames_from_description(context, cursor_description):
+ for (
+ idx,
+ colname,
+ untranslated,
+ coltype,
+ ) in self._colnames_from_description(context, cursor_description):
try:
ctx_rec = result_map[colname]
except KeyError:
@@ -493,8 +548,12 @@ class ResultMetaData(object):
def _merge_cols_by_none(self, context, cursor_description):
dialect = context.dialect
- for idx, colname, untranslated, coltype in \
- self._colnames_from_description(context, cursor_description):
+ for (
+ idx,
+ colname,
+ untranslated,
+ coltype,
+ ) in self._colnames_from_description(context, cursor_description):
yield idx, colname, sqltypes.NULLTYPE, coltype, None, untranslated
@classmethod
@@ -525,27 +584,28 @@ class ResultMetaData(object):
# or column('name') constructs to ColumnElements, or after a
# pickle/unpickle roundtrip
elif isinstance(key, expression.ColumnElement):
- if key._label and (
- key._label
- if self.case_sensitive
- else key._label.lower()) in map:
- result = map[key._label
- if self.case_sensitive
- else key._label.lower()]
- elif hasattr(key, 'name') and (
- key.name
- if self.case_sensitive
- else key.name.lower()) in map:
+ if (
+ key._label
+ and (key._label if self.case_sensitive else key._label.lower())
+ in map
+ ):
+ result = map[
+ key._label if self.case_sensitive else key._label.lower()
+ ]
+ elif (
+ hasattr(key, "name")
+ and (key.name if self.case_sensitive else key.name.lower())
+ in map
+ ):
# match is only on name.
- result = map[key.name
- if self.case_sensitive
- else key.name.lower()]
+ result = map[
+ key.name if self.case_sensitive else key.name.lower()
+ ]
# search extra hard to make sure this
# isn't a column/label name overlap.
# this check isn't currently available if the row
# was unpickled.
- if result is not None and \
- result[1] is not None:
+ if result is not None and result[1] is not None:
for obj in result[1]:
if key._compare_name_for_result(obj):
break
@@ -554,8 +614,9 @@ class ResultMetaData(object):
if result is None:
if raiseerr:
raise exc.NoSuchColumnError(
- "Could not locate column in row for column '%s'" %
- expression._string_or_unprintable(key))
+ "Could not locate column in row for column '%s'"
+ % expression._string_or_unprintable(key)
+ )
else:
return None
else:
@@ -580,34 +641,35 @@ class ResultMetaData(object):
if index is None:
raise exc.InvalidRequestError(
"Ambiguous column name '%s' in "
- "result set column descriptions" % obj)
+ "result set column descriptions" % obj
+ )
return operator.itemgetter(index)
def __getstate__(self):
return {
- '_pickled_keymap': dict(
+ "_pickled_keymap": dict(
(key, index)
for key, (processor, obj, index) in self._keymap.items()
if isinstance(key, util.string_types + util.int_types)
),
- 'keys': self.keys,
+ "keys": self.keys,
"case_sensitive": self.case_sensitive,
- "matched_on_name": self.matched_on_name
+ "matched_on_name": self.matched_on_name,
}
def __setstate__(self, state):
# the row has been processed at pickling time so we don't need any
# processor anymore
- self._processors = [None for _ in range(len(state['keys']))]
+ self._processors = [None for _ in range(len(state["keys"]))]
self._keymap = keymap = {}
- for key, index in state['_pickled_keymap'].items():
+ for key, index in state["_pickled_keymap"].items():
# not preserving "obj" here, unfortunately our
# proxy comparison fails with the unpickle
keymap[key] = (None, None, index)
- self.keys = state['keys']
- self.case_sensitive = state['case_sensitive']
- self.matched_on_name = state['matched_on_name']
+ self.keys = state["keys"]
+ self.case_sensitive = state["case_sensitive"]
+ self.matched_on_name = state["matched_on_name"]
class ResultProxy(object):
@@ -643,8 +705,9 @@ class ResultProxy(object):
self.dialect = context.dialect
self.cursor = self._saved_cursor = context.cursor
self.connection = context.root_connection
- self._echo = self.connection._echo and \
- context.engine._should_log_debug()
+ self._echo = (
+ self.connection._echo and context.engine._should_log_debug()
+ )
self._init_metadata()
def _getter(self, key, raiseerr=True):
@@ -666,18 +729,22 @@ class ResultProxy(object):
def _init_metadata(self):
cursor_description = self._cursor_description()
if cursor_description is not None:
- if self.context.compiled and \
- 'compiled_cache' in self.context.execution_options:
+ if (
+ self.context.compiled
+ and "compiled_cache" in self.context.execution_options
+ ):
if self.context.compiled._cached_metadata:
self._metadata = self.context.compiled._cached_metadata
else:
- self._metadata = self.context.compiled._cached_metadata = \
- ResultMetaData(self, cursor_description)
+ self._metadata = (
+ self.context.compiled._cached_metadata
+ ) = ResultMetaData(self, cursor_description)
else:
self._metadata = ResultMetaData(self, cursor_description)
if self._echo:
self.context.engine.logger.debug(
- "Col %r", tuple(x[0] for x in cursor_description))
+ "Col %r", tuple(x[0] for x in cursor_description)
+ )
def keys(self):
"""Return the current set of string keys for rows."""
@@ -731,7 +798,8 @@ class ResultProxy(object):
return self.context.rowcount
except BaseException as e:
self.connection._handle_dbapi_exception(
- e, None, None, self.cursor, self.context)
+ e, None, None, self.cursor, self.context
+ )
@property
def lastrowid(self):
@@ -753,8 +821,8 @@ class ResultProxy(object):
return self._saved_cursor.lastrowid
except BaseException as e:
self.connection._handle_dbapi_exception(
- e, None, None,
- self._saved_cursor, self.context)
+ e, None, None, self._saved_cursor, self.context
+ )
@property
def returns_rows(self):
@@ -913,17 +981,18 @@ class ResultProxy(object):
if not self.context.compiled:
raise exc.InvalidRequestError(
- "Statement is not a compiled "
- "expression construct.")
+ "Statement is not a compiled " "expression construct."
+ )
elif not self.context.isinsert:
raise exc.InvalidRequestError(
- "Statement is not an insert() "
- "expression construct.")
+ "Statement is not an insert() " "expression construct."
+ )
elif self.context._is_explicit_returning:
raise exc.InvalidRequestError(
"Can't call inserted_primary_key "
"when returning() "
- "is used.")
+ "is used."
+ )
return self.context.inserted_primary_key
@@ -938,12 +1007,12 @@ class ResultProxy(object):
"""
if not self.context.compiled:
raise exc.InvalidRequestError(
- "Statement is not a compiled "
- "expression construct.")
+ "Statement is not a compiled " "expression construct."
+ )
elif not self.context.isupdate:
raise exc.InvalidRequestError(
- "Statement is not an update() "
- "expression construct.")
+ "Statement is not an update() " "expression construct."
+ )
elif self.context.executemany:
return self.context.compiled_parameters
else:
@@ -960,12 +1029,12 @@ class ResultProxy(object):
"""
if not self.context.compiled:
raise exc.InvalidRequestError(
- "Statement is not a compiled "
- "expression construct.")
+ "Statement is not a compiled " "expression construct."
+ )
elif not self.context.isinsert:
raise exc.InvalidRequestError(
- "Statement is not an insert() "
- "expression construct.")
+ "Statement is not an insert() " "expression construct."
+ )
elif self.context.executemany:
return self.context.compiled_parameters
else:
@@ -1013,12 +1082,13 @@ class ResultProxy(object):
if not self.context.compiled:
raise exc.InvalidRequestError(
- "Statement is not a compiled "
- "expression construct.")
+ "Statement is not a compiled " "expression construct."
+ )
elif not self.context.isinsert and not self.context.isupdate:
raise exc.InvalidRequestError(
"Statement is not an insert() or update() "
- "expression construct.")
+ "expression construct."
+ )
return self.context.postfetch_cols
def prefetch_cols(self):
@@ -1035,12 +1105,13 @@ class ResultProxy(object):
if not self.context.compiled:
raise exc.InvalidRequestError(
- "Statement is not a compiled "
- "expression construct.")
+ "Statement is not a compiled " "expression construct."
+ )
elif not self.context.isinsert and not self.context.isupdate:
raise exc.InvalidRequestError(
"Statement is not an insert() or update() "
- "expression construct.")
+ "expression construct."
+ )
return self.context.prefetch_cols
def supports_sane_rowcount(self):
@@ -1086,7 +1157,7 @@ class ResultProxy(object):
if self._metadata is None:
raise exc.ResourceClosedError(
"This result object does not return rows. "
- "It has been closed automatically.",
+ "It has been closed automatically."
)
elif self.closed:
raise exc.ResourceClosedError("This result object is closed.")
@@ -1106,8 +1177,9 @@ class ResultProxy(object):
l.append(process_row(metadata, row, processors, keymap))
return l
else:
- return [process_row(metadata, row, processors, keymap)
- for row in rows]
+ return [
+ process_row(metadata, row, processors, keymap) for row in rows
+ ]
def fetchall(self):
"""Fetch all rows, just like DB-API ``cursor.fetchall()``.
@@ -1132,8 +1204,8 @@ class ResultProxy(object):
return l
except BaseException as e:
self.connection._handle_dbapi_exception(
- e, None, None,
- self.cursor, self.context)
+ e, None, None, self.cursor, self.context
+ )
def fetchmany(self, size=None):
"""Fetch many rows, just like DB-API
@@ -1161,8 +1233,8 @@ class ResultProxy(object):
return l
except BaseException as e:
self.connection._handle_dbapi_exception(
- e, None, None,
- self.cursor, self.context)
+ e, None, None, self.cursor, self.context
+ )
def fetchone(self):
"""Fetch one row, just like DB-API ``cursor.fetchone()``.
@@ -1190,8 +1262,8 @@ class ResultProxy(object):
return None
except BaseException as e:
self.connection._handle_dbapi_exception(
- e, None, None,
- self.cursor, self.context)
+ e, None, None, self.cursor, self.context
+ )
def first(self):
"""Fetch the first row and then close the result set unconditionally.
@@ -1209,8 +1281,8 @@ class ResultProxy(object):
row = self._fetchone_impl()
except BaseException as e:
self.connection._handle_dbapi_exception(
- e, None, None,
- self.cursor, self.context)
+ e, None, None, self.cursor, self.context
+ )
try:
if row is not None:
@@ -1268,7 +1340,8 @@ class BufferedRowResultProxy(ResultProxy):
def _init_metadata(self):
self._max_row_buffer = self.context.execution_options.get(
- 'max_row_buffer', None)
+ "max_row_buffer", None
+ )
self.__buffer_rows()
super(BufferedRowResultProxy, self)._init_metadata()
@@ -1284,13 +1357,13 @@ class BufferedRowResultProxy(ResultProxy):
50: 100,
100: 250,
250: 500,
- 500: 1000
+ 500: 1000,
}
def __buffer_rows(self):
if self.cursor is None:
return
- size = getattr(self, '_bufsize', 1)
+ size = getattr(self, "_bufsize", 1)
self.__rowbuffer = collections.deque(self.cursor.fetchmany(size))
self._bufsize = self.size_growth.get(size, size)
if self._max_row_buffer is not None:
@@ -1385,8 +1458,9 @@ class BufferedColumnRow(RowProxy):
row[index] = processor(row[index])
index += 1
row = tuple(row)
- super(BufferedColumnRow, self).__init__(parent, row,
- processors, keymap)
+ super(BufferedColumnRow, self).__init__(
+ parent, row, processors, keymap
+ )
class BufferedColumnResultProxy(ResultProxy):
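A note on the buffering logic reformatted above: BufferedRowResultProxy fetches rows in batches that grow per the size_growth table, and the "max_row_buffer" execution option caps that growth. A minimal sketch of how a caller would exercise it, assuming the psycopg2 dialect (which selects this proxy when stream_results is set) and a hypothetical big_table:

    from sqlalchemy import create_engine

    # hypothetical URL and table; psycopg2 pairs BufferedRowResultProxy
    # with a server-side cursor when stream_results is enabled
    engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test")

    with engine.connect() as conn:
        result = conn.execution_options(
            stream_results=True, max_row_buffer=500
        ).execute("SELECT * FROM big_table")
        for row in result:
            pass  # rows arrive in progressively larger fetchmany() batches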
diff --git a/lib/sqlalchemy/engine/strategies.py b/lib/sqlalchemy/engine/strategies.py
index d4f5185de..4aecb9537 100644
--- a/lib/sqlalchemy/engine/strategies.py
+++ b/lib/sqlalchemy/engine/strategies.py
@@ -51,18 +51,20 @@ class DefaultEngineStrategy(EngineStrategy):
plugins = u._instantiate_plugins(kwargs)
- u.query.pop('plugin', None)
- kwargs.pop('plugins', None)
+ u.query.pop("plugin", None)
+ kwargs.pop("plugins", None)
entrypoint = u._get_entrypoint()
dialect_cls = entrypoint.get_dialect_cls(u)
- if kwargs.pop('_coerce_config', False):
+ if kwargs.pop("_coerce_config", False):
+
def pop_kwarg(key, default=None):
value = kwargs.pop(key, default)
if key in dialect_cls.engine_config_types:
value = dialect_cls.engine_config_types[key](value)
return value
+
else:
pop_kwarg = kwargs.pop
@@ -72,7 +74,7 @@ class DefaultEngineStrategy(EngineStrategy):
if k in kwargs:
dialect_args[k] = pop_kwarg(k)
- dbapi = kwargs.pop('module', None)
+ dbapi = kwargs.pop("module", None)
if dbapi is None:
dbapi_args = {}
for k in util.get_func_kwargs(dialect_cls.dbapi):
@@ -80,7 +82,7 @@ class DefaultEngineStrategy(EngineStrategy):
dbapi_args[k] = pop_kwarg(k)
dbapi = dialect_cls.dbapi(**dbapi_args)
- dialect_args['dbapi'] = dbapi
+ dialect_args["dbapi"] = dbapi
for plugin in plugins:
plugin.handle_dialect_kwargs(dialect_cls, dialect_args)
@@ -90,41 +92,43 @@ class DefaultEngineStrategy(EngineStrategy):
# assemble connection arguments
(cargs, cparams) = dialect.create_connect_args(u)
- cparams.update(pop_kwarg('connect_args', {}))
+ cparams.update(pop_kwarg("connect_args", {}))
cargs = list(cargs) # allow mutability
# look for existing pool or create
- pool = pop_kwarg('pool', None)
+ pool = pop_kwarg("pool", None)
if pool is None:
+
def connect(connection_record=None):
if dialect._has_events:
for fn in dialect.dispatch.do_connect:
connection = fn(
- dialect, connection_record, cargs, cparams)
+ dialect, connection_record, cargs, cparams
+ )
if connection is not None:
return connection
return dialect.connect(*cargs, **cparams)
- creator = pop_kwarg('creator', connect)
+ creator = pop_kwarg("creator", connect)
- poolclass = pop_kwarg('poolclass', None)
+ poolclass = pop_kwarg("poolclass", None)
if poolclass is None:
poolclass = dialect_cls.get_pool_class(u)
- pool_args = {
- 'dialect': dialect
- }
+ pool_args = {"dialect": dialect}
# consume pool arguments from kwargs, translating a few of
# the arguments
- translate = {'logging_name': 'pool_logging_name',
- 'echo': 'echo_pool',
- 'timeout': 'pool_timeout',
- 'recycle': 'pool_recycle',
- 'events': 'pool_events',
- 'use_threadlocal': 'pool_threadlocal',
- 'reset_on_return': 'pool_reset_on_return',
- 'pre_ping': 'pool_pre_ping',
- 'use_lifo': 'pool_use_lifo'}
+ translate = {
+ "logging_name": "pool_logging_name",
+ "echo": "echo_pool",
+ "timeout": "pool_timeout",
+ "recycle": "pool_recycle",
+ "events": "pool_events",
+ "use_threadlocal": "pool_threadlocal",
+ "reset_on_return": "pool_reset_on_return",
+ "pre_ping": "pool_pre_ping",
+ "use_lifo": "pool_use_lifo",
+ }
for k in util.get_cls_kwargs(poolclass):
tk = translate.get(k, k)
if tk in kwargs:
@@ -149,7 +153,7 @@ class DefaultEngineStrategy(EngineStrategy):
if k in kwargs:
engine_args[k] = pop_kwarg(k)
- _initialize = kwargs.pop('_initialize', True)
+ _initialize = kwargs.pop("_initialize", True)
# all kwargs should be consumed
if kwargs:
@@ -157,32 +161,40 @@ class DefaultEngineStrategy(EngineStrategy):
"Invalid argument(s) %s sent to create_engine(), "
"using configuration %s/%s/%s. Please check that the "
"keyword arguments are appropriate for this combination "
- "of components." % (','.join("'%s'" % k for k in kwargs),
- dialect.__class__.__name__,
- pool.__class__.__name__,
- engineclass.__name__))
+ "of components."
+ % (
+ ",".join("'%s'" % k for k in kwargs),
+ dialect.__class__.__name__,
+ pool.__class__.__name__,
+ engineclass.__name__,
+ )
+ )
engine = engineclass(pool, dialect, u, **engine_args)
if _initialize:
do_on_connect = dialect.on_connect()
if do_on_connect:
+
def on_connect(dbapi_connection, connection_record):
conn = getattr(
- dbapi_connection, '_sqla_unwrap', dbapi_connection)
+ dbapi_connection, "_sqla_unwrap", dbapi_connection
+ )
if conn is None:
return
do_on_connect(conn)
- event.listen(pool, 'first_connect', on_connect)
- event.listen(pool, 'connect', on_connect)
+ event.listen(pool, "first_connect", on_connect)
+ event.listen(pool, "connect", on_connect)
def first_connect(dbapi_connection, connection_record):
- c = base.Connection(engine, connection=dbapi_connection,
- _has_events=False)
+ c = base.Connection(
+ engine, connection=dbapi_connection, _has_events=False
+ )
c._execution_options = util.immutabledict()
dialect.initialize(c)
- event.listen(pool, 'first_connect', first_connect, once=True)
+
+ event.listen(pool, "first_connect", first_connect, once=True)
dialect_cls.engine_created(engine)
if entrypoint is not dialect_cls:
@@ -197,18 +209,20 @@ class DefaultEngineStrategy(EngineStrategy):
class PlainEngineStrategy(DefaultEngineStrategy):
"""Strategy for configuring a regular Engine."""
- name = 'plain'
+ name = "plain"
engine_cls = base.Engine
+
PlainEngineStrategy()
class ThreadLocalEngineStrategy(DefaultEngineStrategy):
"""Strategy for configuring an Engine with threadlocal behavior."""
- name = 'threadlocal'
+ name = "threadlocal"
engine_cls = threadlocal.TLEngine
+
ThreadLocalEngineStrategy()
@@ -220,7 +234,7 @@ class MockEngineStrategy(EngineStrategy):
"""
- name = 'mock'
+ name = "mock"
def create(self, name_or_url, executor, **kwargs):
# create url.URL object
@@ -245,7 +259,7 @@ class MockEngineStrategy(EngineStrategy):
self.execute = execute
engine = property(lambda s: s)
- dialect = property(attrgetter('_dialect'))
+ dialect = property(attrgetter("_dialect"))
name = property(lambda s: s._dialect.name)
schema_for_object = schema._schema_getter(None)
@@ -258,29 +272,35 @@ class MockEngineStrategy(EngineStrategy):
def compiler(self, statement, parameters, **kwargs):
return self._dialect.compiler(
- statement, parameters, engine=self, **kwargs)
+ statement, parameters, engine=self, **kwargs
+ )
def create(self, entity, **kwargs):
- kwargs['checkfirst'] = False
+ kwargs["checkfirst"] = False
from sqlalchemy.engine import ddl
- ddl.SchemaGenerator(
- self.dialect, self, **kwargs).traverse_single(entity)
+ ddl.SchemaGenerator(self.dialect, self, **kwargs).traverse_single(
+ entity
+ )
def drop(self, entity, **kwargs):
- kwargs['checkfirst'] = False
+ kwargs["checkfirst"] = False
from sqlalchemy.engine import ddl
- ddl.SchemaDropper(
- self.dialect, self, **kwargs).traverse_single(entity)
- def _run_visitor(self, visitorcallable, element,
- connection=None,
- **kwargs):
- kwargs['checkfirst'] = False
- visitorcallable(self.dialect, self,
- **kwargs).traverse_single(element)
+ ddl.SchemaDropper(self.dialect, self, **kwargs).traverse_single(
+ entity
+ )
+
+ def _run_visitor(
+ self, visitorcallable, element, connection=None, **kwargs
+ ):
+ kwargs["checkfirst"] = False
+ visitorcallable(self.dialect, self, **kwargs).traverse_single(
+ element
+ )
def execute(self, object, *multiparams, **params):
raise NotImplementedError()
+
MockEngineStrategy()
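For orientation, the "mock" strategy touched above is what backs the documented DDL-dump recipe: compiled constructs are routed to a caller-supplied executor instead of a DBAPI. A sketch, where metadata is assumed to hold some Table definitions:

    from sqlalchemy import create_engine, MetaData

    metadata = MetaData()  # assume Table objects are defined on this

    def dump(sql, *multiparams, **params):
        # receives each DDL construct in lieu of executing it
        print(sql.compile(dialect=engine.dialect))

    engine = create_engine("postgresql://", strategy="mock", executor=dump)
    metadata.create_all(engine, checkfirst=False)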
diff --git a/lib/sqlalchemy/engine/threadlocal.py b/lib/sqlalchemy/engine/threadlocal.py
index 0ec1f9613..5b2bdabc0 100644
--- a/lib/sqlalchemy/engine/threadlocal.py
+++ b/lib/sqlalchemy/engine/threadlocal.py
@@ -19,7 +19,6 @@ import weakref
class TLConnection(base.Connection):
-
def __init__(self, *arg, **kw):
super(TLConnection, self).__init__(*arg, **kw)
self.__opencount = 0
@@ -43,6 +42,7 @@ class TLEngine(base.Engine):
transactions.
"""
+
_tl_connection_cls = TLConnection
def __init__(self, *args, **kwargs):
@@ -50,7 +50,7 @@ class TLEngine(base.Engine):
self._connections = util.threading.local()
def contextual_connect(self, **kw):
- if not hasattr(self._connections, 'conn'):
+ if not hasattr(self._connections, "conn"):
connection = None
else:
connection = self._connections.conn()
@@ -60,29 +60,31 @@ class TLEngine(base.Engine):
# or not connection.connection.is_valid:
connection = self._tl_connection_cls(
self,
- self._wrap_pool_connect(
- self.pool.connect, connection),
- **kw)
+ self._wrap_pool_connect(self.pool.connect, connection),
+ **kw
+ )
self._connections.conn = weakref.ref(connection)
return connection._increment_connect()
def begin_twophase(self, xid=None):
- if not hasattr(self._connections, 'trans'):
+ if not hasattr(self._connections, "trans"):
self._connections.trans = []
self._connections.trans.append(
- self.contextual_connect().begin_twophase(xid=xid))
+ self.contextual_connect().begin_twophase(xid=xid)
+ )
return self
def begin_nested(self):
- if not hasattr(self._connections, 'trans'):
+ if not hasattr(self._connections, "trans"):
self._connections.trans = []
self._connections.trans.append(
- self.contextual_connect().begin_nested())
+ self.contextual_connect().begin_nested()
+ )
return self
def begin(self):
- if not hasattr(self._connections, 'trans'):
+ if not hasattr(self._connections, "trans"):
self._connections.trans = []
self._connections.trans.append(self.contextual_connect().begin())
return self
@@ -97,21 +99,27 @@ class TLEngine(base.Engine):
self.rollback()
def prepare(self):
- if not hasattr(self._connections, 'trans') or \
- not self._connections.trans:
+ if (
+ not hasattr(self._connections, "trans")
+ or not self._connections.trans
+ ):
return
self._connections.trans[-1].prepare()
def commit(self):
- if not hasattr(self._connections, 'trans') or \
- not self._connections.trans:
+ if (
+ not hasattr(self._connections, "trans")
+ or not self._connections.trans
+ ):
return
trans = self._connections.trans.pop(-1)
trans.commit()
def rollback(self):
- if not hasattr(self._connections, 'trans') or \
- not self._connections.trans:
+ if (
+ not hasattr(self._connections, "trans")
+ or not self._connections.trans
+ ):
return
trans = self._connections.trans.pop(-1)
trans.rollback()
@@ -122,9 +130,11 @@ class TLEngine(base.Engine):
@property
def closed(self):
- return not hasattr(self._connections, 'conn') or \
- self._connections.conn() is None or \
- self._connections.conn().closed
+ return (
+ not hasattr(self._connections, "conn")
+ or self._connections.conn() is None
+ or self._connections.conn().closed
+ )
def close(self):
if not self.closed:
@@ -135,4 +145,4 @@ class TLEngine(base.Engine):
self._connections.trans = []
def __repr__(self):
- return 'TLEngine(%r)' % self.url
+ return "TLEngine(%r)" % self.url
diff --git a/lib/sqlalchemy/engine/url.py b/lib/sqlalchemy/engine/url.py
index 1662efe20..e92e57b8e 100644
--- a/lib/sqlalchemy/engine/url.py
+++ b/lib/sqlalchemy/engine/url.py
@@ -50,8 +50,16 @@ class URL(object):
"""
- def __init__(self, drivername, username=None, password=None,
- host=None, port=None, database=None, query=None):
+ def __init__(
+ self,
+ drivername,
+ username=None,
+ password=None,
+ host=None,
+ port=None,
+ database=None,
+ query=None,
+ ):
self.drivername = drivername
self.username = username
self.password_original = password
@@ -68,26 +76,26 @@ class URL(object):
if self.username is not None:
s += _rfc_1738_quote(self.username)
if self.password is not None:
- s += ':' + ('***' if hide_password
- else _rfc_1738_quote(self.password))
+ s += ":" + (
+ "***" if hide_password else _rfc_1738_quote(self.password)
+ )
s += "@"
if self.host is not None:
- if ':' in self.host:
+ if ":" in self.host:
s += "[%s]" % self.host
else:
s += self.host
if self.port is not None:
- s += ':' + str(self.port)
+ s += ":" + str(self.port)
if self.database is not None:
- s += '/' + self.database
+ s += "/" + self.database
if self.query:
keys = list(self.query)
keys.sort()
- s += '?' + "&".join(
- "%s=%s" % (
- k,
- element
- ) for k in keys for element in util.to_list(self.query[k])
+ s += "?" + "&".join(
+ "%s=%s" % (k, element)
+ for k in keys
+ for element in util.to_list(self.query[k])
)
return s
@@ -101,14 +109,15 @@ class URL(object):
return hash(str(self))
def __eq__(self, other):
- return \
- isinstance(other, URL) and \
- self.drivername == other.drivername and \
- self.username == other.username and \
- self.password == other.password and \
- self.host == other.host and \
- self.database == other.database and \
- self.query == other.query
+ return (
+ isinstance(other, URL)
+ and self.drivername == other.drivername
+ and self.username == other.username
+ and self.password == other.password
+ and self.host == other.host
+ and self.port == other.port
+ and self.database == other.database
+ and self.query == other.query
+ )
@property
def password(self):
@@ -122,20 +131,20 @@ class URL(object):
self.password_original = password
def get_backend_name(self):
- if '+' not in self.drivername:
+ if "+" not in self.drivername:
return self.drivername
else:
- return self.drivername.split('+')[0]
+ return self.drivername.split("+")[0]
def get_driver_name(self):
- if '+' not in self.drivername:
+ if "+" not in self.drivername:
return self.get_dialect().driver
else:
- return self.drivername.split('+')[1]
+ return self.drivername.split("+")[1]
def _instantiate_plugins(self, kwargs):
- plugin_names = util.to_list(self.query.get('plugin', ()))
- plugin_names += kwargs.get('plugins', [])
+ plugin_names = util.to_list(self.query.get("plugin", ()))
+ plugin_names += kwargs.get("plugins", [])
return [
plugins.load(plugin_name)(self, kwargs)
@@ -149,17 +158,19 @@ class URL(object):
returned class implements the get_dialect_cls() method.
"""
- if '+' not in self.drivername:
+ if "+" not in self.drivername:
name = self.drivername
else:
- name = self.drivername.replace('+', '.')
+ name = self.drivername.replace("+", ".")
cls = registry.load(name)
# check for legacy dialects that
# would return a module with 'dialect' as the
# actual class
- if hasattr(cls, 'dialect') and \
- isinstance(cls.dialect, type) and \
- issubclass(cls.dialect, Dialect):
+ if (
+ hasattr(cls, "dialect")
+ and isinstance(cls.dialect, type)
+ and issubclass(cls.dialect, Dialect)
+ ):
return cls.dialect
else:
return cls
@@ -187,7 +198,7 @@ class URL(object):
"""
translated = {}
- attribute_names = ['host', 'database', 'username', 'password', 'port']
+ attribute_names = ["host", "database", "username", "password", "port"]
for sname in attribute_names:
if names:
name = names.pop(0)
@@ -214,7 +225,8 @@ def make_url(name_or_url):
def _parse_rfc1738_args(name):
- pattern = re.compile(r'''
+ pattern = re.compile(
+ r"""
(?P<name>[\w\+]+)://
(?:
(?P<username>[^:/]*)
@@ -228,21 +240,23 @@ def _parse_rfc1738_args(name):
(?::(?P<port>[^/]*))?
)?
(?:/(?P<database>.*))?
- ''', re.X)
+ """,
+ re.X,
+ )
m = pattern.match(name)
if m is not None:
components = m.groupdict()
- if components['database'] is not None:
- tokens = components['database'].split('?', 2)
- components['database'] = tokens[0]
+ if components["database"] is not None:
+ tokens = components["database"].split("?", 2)
+ components["database"] = tokens[0]
if len(tokens) > 1:
query = {}
for key, value in util.parse_qsl(tokens[1]):
if util.py2k:
- key = key.encode('ascii')
+ key = key.encode("ascii")
if key in query:
query[key] = util.to_list(query[key])
query[key].append(value)
@@ -252,26 +266,27 @@ def _parse_rfc1738_args(name):
query = None
else:
query = None
- components['query'] = query
+ components["query"] = query
- if components['username'] is not None:
- components['username'] = _rfc_1738_unquote(components['username'])
+ if components["username"] is not None:
+ components["username"] = _rfc_1738_unquote(components["username"])
- if components['password'] is not None:
- components['password'] = _rfc_1738_unquote(components['password'])
+ if components["password"] is not None:
+ components["password"] = _rfc_1738_unquote(components["password"])
- ipv4host = components.pop('ipv4host')
- ipv6host = components.pop('ipv6host')
- components['host'] = ipv4host or ipv6host
- name = components.pop('name')
+ ipv4host = components.pop("ipv4host")
+ ipv6host = components.pop("ipv6host")
+ components["host"] = ipv4host or ipv6host
+ name = components.pop("name")
return URL(name, **components)
else:
raise exc.ArgumentError(
- "Could not parse rfc1738 URL from string '%s'" % name)
+ "Could not parse rfc1738 URL from string '%s'" % name
+ )
def _rfc_1738_quote(text):
- return re.sub(r'[:@/]', lambda m: "%%%X" % ord(m.group(0)), text)
+ return re.sub(r"[:@/]", lambda m: "%%%X" % ord(m.group(0)), text)
def _rfc_1738_unquote(text):
@@ -279,7 +294,7 @@ def _rfc_1738_unquote(text):
def _parse_keyvalue_args(name):
- m = re.match(r'(\w+)://(.*)', name)
+ m = re.match(r"(\w+)://(.*)", name)
if m is not None:
(name, args) = m.group(1, 2)
opts = dict(util.parse_qsl(args))
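To make the URL handling above concrete, a round trip through make_url (all values are made up):

    from sqlalchemy.engine.url import make_url

    url = make_url(
        "postgresql+psycopg2://scott:tiger@localhost:5432/test?sslmode=require"
    )

    assert url.get_backend_name() == "postgresql"
    assert url.get_driver_name() == "psycopg2"
    assert url.host == "localhost" and url.port == 5432
    assert url.query == {"sslmode": "require"}

    # per __to_string__ above, repr() masks the password as '***'
    # while str() renders it verbatim
    print(repr(url))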
diff --git a/lib/sqlalchemy/engine/util.py b/lib/sqlalchemy/engine/util.py
index 17bc9a3b4..76bb8f4b5 100644
--- a/lib/sqlalchemy/engine/util.py
+++ b/lib/sqlalchemy/engine/util.py
@@ -46,28 +46,34 @@ def py_fallback():
elif len(multiparams) == 1:
zero = multiparams[0]
if isinstance(zero, (list, tuple)):
- if not zero or hasattr(zero[0], '__iter__') and \
- not hasattr(zero[0], 'strip'):
+ if (
+ not zero
+ or hasattr(zero[0], "__iter__")
+ and not hasattr(zero[0], "strip")
+ ):
# execute(stmt, [{}, {}, {}, ...])
# execute(stmt, [(), (), (), ...])
return zero
else:
# execute(stmt, ("value", "value"))
return [zero]
- elif hasattr(zero, 'keys'):
+ elif hasattr(zero, "keys"):
# execute(stmt, {"key":"value"})
return [zero]
else:
# execute(stmt, "value")
return [[zero]]
else:
- if hasattr(multiparams[0], '__iter__') and \
- not hasattr(multiparams[0], 'strip'):
+ if hasattr(multiparams[0], "__iter__") and not hasattr(
+ multiparams[0], "strip"
+ ):
return multiparams
else:
return [multiparams]
return locals()
+
+
try:
from sqlalchemy.cutils import _distill_params
except ImportError:
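The fallback above mirrors sqlalchemy.cutils._distill_params, which normalizes the *multiparams given to execute() into a list of parameter sets. Its contract, illustrated (this is a private helper, shown only to clarify the branches being reformatted):

    from sqlalchemy.engine.util import _distill_params

    # execute(stmt, {"x": 1}) -> single dict becomes a one-element list
    assert _distill_params(({"x": 1},), {}) == [{"x": 1}]

    # execute(stmt, [{"x": 1}, {"x": 2}]) -> executemany passes through
    assert _distill_params(([{"x": 1}, {"x": 2}],), {}) == [
        {"x": 1},
        {"x": 2},
    ]

    # execute(stmt, ("a", "b")) -> one positional parameter set
    assert _distill_params((("a", "b"),), {}) == [("a", "b")]

    # execute(stmt, "value") -> bare scalar is wrapped twice
    assert _distill_params(("value",), {}) == [["value"]]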
diff --git a/lib/sqlalchemy/event/api.py b/lib/sqlalchemy/event/api.py
index acfacc233..f9e04503c 100644
--- a/lib/sqlalchemy/event/api.py
+++ b/lib/sqlalchemy/event/api.py
@@ -14,8 +14,8 @@ from .. import util, exc
from .base import _registrars
from .registry import _EventKey
-CANCEL = util.symbol('CANCEL')
-NO_RETVAL = util.symbol('NO_RETVAL')
+CANCEL = util.symbol("CANCEL")
+NO_RETVAL = util.symbol("NO_RETVAL")
def _event_key(target, identifier, fn):
@@ -24,8 +24,9 @@ def _event_key(target, identifier, fn):
if tgt is not None:
return _EventKey(target, identifier, fn, tgt)
else:
- raise exc.InvalidRequestError("No such event '%s' for target '%s'" %
- (identifier, target))
+ raise exc.InvalidRequestError(
+ "No such event '%s' for target '%s'" % (identifier, target)
+ )
def listen(target, identifier, fn, *args, **kw):
@@ -120,9 +121,11 @@ def listens_for(target, identifier, *args, **kw):
:func:`.listen` - general description of event listening
"""
+
def decorate(fn):
listen(target, identifier, fn, *args, **kw)
return fn
+
return decorate
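As a usage reminder for the decorator above (the engine and listener body are illustrative):

    from sqlalchemy import create_engine, event

    engine = create_engine("sqlite://")

    @event.listens_for(engine, "connect", once=True)
    def on_first_connect(dbapi_connection, connection_record):
        # once=True wraps the listener with util.only_once()
        print("connected:", dbapi_connection)

    engine.connect().close()  # fires the listener exactly once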
diff --git a/lib/sqlalchemy/event/attr.py b/lib/sqlalchemy/event/attr.py
index c33ec82ff..31a9f28ca 100644
--- a/lib/sqlalchemy/event/attr.py
+++ b/lib/sqlalchemy/event/attr.py
@@ -41,7 +41,7 @@ import collections
class RefCollection(util.MemoizedSlots):
- __slots__ = 'ref',
+ __slots__ = ("ref",)
def _memoized_attr_ref(self):
return weakref.ref(self, registry._collection_gced)
@@ -67,20 +67,27 @@ class _empty_collection(object):
class _ClsLevelDispatch(RefCollection):
"""Class-level events on :class:`._Dispatch` classes."""
- __slots__ = ('name', 'arg_names', 'has_kw',
- 'legacy_signatures', '_clslevel', '__weakref__')
+ __slots__ = (
+ "name",
+ "arg_names",
+ "has_kw",
+ "legacy_signatures",
+ "_clslevel",
+ "__weakref__",
+ )
def __init__(self, parent_dispatch_cls, fn):
self.name = fn.__name__
argspec = util.inspect_getargspec(fn)
self.arg_names = argspec.args[1:]
self.has_kw = bool(argspec.keywords)
- self.legacy_signatures = list(reversed(
- sorted(
- getattr(fn, '_legacy_signatures', []),
- key=lambda s: s[0]
+ self.legacy_signatures = list(
+ reversed(
+ sorted(
+ getattr(fn, "_legacy_signatures", []), key=lambda s: s[0]
+ )
)
- ))
+ )
fn.__doc__ = legacy._augment_fn_docs(self, parent_dispatch_cls, fn)
self._clslevel = weakref.WeakKeyDictionary()
@@ -102,15 +109,18 @@ class _ClsLevelDispatch(RefCollection):
argdict = dict(zip(self.arg_names, args))
argdict.update(kw)
return fn(**argdict)
+
return wrap_kw
def insert(self, event_key, propagate):
target = event_key.dispatch_target
- assert isinstance(target, type), \
- "Class-level Event targets must be classes."
- if not getattr(target, '_sa_propagate_class_events', True):
+ assert isinstance(
+ target, type
+ ), "Class-level Event targets must be classes."
+ if not getattr(target, "_sa_propagate_class_events", True):
raise exc.InvalidRequestError(
- "Can't assign an event directly to the %s class" % target)
+ "Can't assign an event directly to the %s class" % target
+ )
stack = [target]
while stack:
cls = stack.pop(0)
@@ -125,11 +135,13 @@ class _ClsLevelDispatch(RefCollection):
def append(self, event_key, propagate):
target = event_key.dispatch_target
- assert isinstance(target, type), \
- "Class-level Event targets must be classes."
- if not getattr(target, '_sa_propagate_class_events', True):
+ assert isinstance(
+ target, type
+ ), "Class-level Event targets must be classes."
+ if not getattr(target, "_sa_propagate_class_events", True):
raise exc.InvalidRequestError(
- "Can't assign an event directly to the %s class" % target)
+ "Can't assign an event directly to the %s class" % target
+ )
stack = [target]
while stack:
cls = stack.pop(0)
@@ -143,7 +155,7 @@ class _ClsLevelDispatch(RefCollection):
registry._stored_in_collection(event_key, self)
def _assign_cls_collection(self, target):
- if getattr(target, '_sa_propagate_class_events', True):
+ if getattr(target, "_sa_propagate_class_events", True):
self._clslevel[target] = collections.deque()
else:
self._clslevel[target] = _empty_collection()
@@ -154,11 +166,9 @@ class _ClsLevelDispatch(RefCollection):
clslevel = self._clslevel[target]
for cls in target.__mro__[1:]:
if cls in self._clslevel:
- clslevel.extend([
- fn for fn
- in self._clslevel[cls]
- if fn not in clslevel
- ])
+ clslevel.extend(
+ [fn for fn in self._clslevel[cls] if fn not in clslevel]
+ )
def remove(self, event_key):
target = event_key.dispatch_target
@@ -209,7 +219,7 @@ class _EmptyListener(_InstanceLevelDispatch):
propagate = frozenset()
listeners = ()
- __slots__ = 'parent', 'parent_listeners', 'name'
+ __slots__ = "parent", "parent_listeners", "name"
def __init__(self, parent, target_cls):
if target_cls not in parent._clslevel:
@@ -258,7 +268,7 @@ class _EmptyListener(_InstanceLevelDispatch):
class _CompoundListener(_InstanceLevelDispatch):
- __slots__ = '_exec_once_mutex', '_exec_once'
+ __slots__ = "_exec_once_mutex", "_exec_once"
def _memoized_attr__exec_once_mutex(self):
return threading.Lock()
@@ -306,8 +316,13 @@ class _ListenerCollection(_CompoundListener):
"""
__slots__ = (
- 'parent_listeners', 'parent', 'name', 'listeners',
- 'propagate', '__weakref__')
+ "parent_listeners",
+ "parent",
+ "name",
+ "listeners",
+ "propagate",
+ "__weakref__",
+ )
def __init__(self, parent, target_cls):
if target_cls not in parent._clslevel:
@@ -335,11 +350,13 @@ class _ListenerCollection(_CompoundListener):
existing_listeners = self.listeners
existing_listener_set = set(existing_listeners)
self.propagate.update(other.propagate)
- other_listeners = [l for l
- in other.listeners
- if l not in existing_listener_set
- and not only_propagate or l in self.propagate
- ]
+ other_listeners = [
+ l
+ for l in other.listeners
+ if l not in existing_listener_set
+ and not only_propagate
+ or l in self.propagate
+ ]
existing_listeners.extend(other_listeners)
@@ -368,7 +385,7 @@ class _ListenerCollection(_CompoundListener):
class _JoinedListener(_CompoundListener):
- __slots__ = 'parent', 'name', 'local', 'parent_listeners'
+ __slots__ = "parent", "name", "local", "parent_listeners"
def __init__(self, parent, name, local):
self._exec_once = False
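The wrap_kw adapter built by _ClsLevelDispatch above is what services the public named=True flag: positional event arguments are zipped into a dict keyed by arg_names. A sketch of the listener-side view (engine is illustrative):

    from sqlalchemy import create_engine, event

    engine = create_engine("sqlite://")

    @event.listens_for(engine, "before_cursor_execute", named=True)
    def receive(**kw):
        # wrap_kw delivered conn, cursor, statement, parameters,
        # context and executemany as keyword arguments
        print(kw["statement"], kw["parameters"])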
diff --git a/lib/sqlalchemy/event/base.py b/lib/sqlalchemy/event/base.py
index 137aec258..c750be70a 100644
--- a/lib/sqlalchemy/event/base.py
+++ b/lib/sqlalchemy/event/base.py
@@ -26,7 +26,7 @@ _registrars = util.defaultdict(list)
def _is_event_name(name):
- return not name.startswith('_') and name != 'dispatch'
+ return not name.startswith("_") and name != "dispatch"
class _UnpickleDispatch(object):
@@ -37,8 +37,8 @@ class _UnpickleDispatch(object):
def __call__(self, _instance_cls):
for cls in _instance_cls.__mro__:
- if 'dispatch' in cls.__dict__:
- return cls.__dict__['dispatch'].dispatch._for_class(
+ if "dispatch" in cls.__dict__:
+ return cls.__dict__["dispatch"].dispatch._for_class(
_instance_cls
)
else:
@@ -67,7 +67,7 @@ class _Dispatch(object):
# In one ORM edge case, an attribute is added to _Dispatch,
# so __dict__ is used in just that case and potentially others.
- __slots__ = '_parent', '_instance_cls', '__dict__', '_empty_listeners'
+ __slots__ = "_parent", "_instance_cls", "__dict__", "_empty_listeners"
_empty_listener_reg = weakref.WeakKeyDictionary()
@@ -79,7 +79,9 @@ class _Dispatch(object):
try:
self._empty_listeners = self._empty_listener_reg[instance_cls]
except KeyError:
- self._empty_listeners = self._empty_listener_reg[instance_cls] = {
+ self._empty_listeners = self._empty_listener_reg[
+ instance_cls
+ ] = {
ls.name: _EmptyListener(ls, instance_cls)
for ls in parent._event_descriptors
}
@@ -122,17 +124,18 @@ class _Dispatch(object):
:class:`._Dispatch` objects.
"""
- if '_joined_dispatch_cls' not in self.__class__.__dict__:
+ if "_joined_dispatch_cls" not in self.__class__.__dict__:
cls = type(
"Joined%s" % self.__class__.__name__,
- (_JoinedDispatcher, ), {'__slots__': self._event_names}
+ (_JoinedDispatcher,),
+ {"__slots__": self._event_names},
)
self.__class__._joined_dispatch_cls = cls
return self._joined_dispatch_cls(self, other)
def __reduce__(self):
- return _UnpickleDispatch(), (self._instance_cls, )
+ return _UnpickleDispatch(), (self._instance_cls,)
def _update(self, other, only_propagate=True):
"""Populate from the listeners in another :class:`_Dispatch`
@@ -140,8 +143,9 @@ class _Dispatch(object):
for ls in other._event_descriptors:
if isinstance(ls, _EmptyListener):
continue
- getattr(self, ls.name).\
- for_modify(self)._update(ls, only_propagate=only_propagate)
+ getattr(self, ls.name).for_modify(self)._update(
+ ls, only_propagate=only_propagate
+ )
def _clear(self):
for ls in self._event_descriptors:
@@ -164,14 +168,15 @@ def _create_dispatcher_class(cls, classname, bases, dict_):
# there's all kinds of ways to do this,
# i.e. make a Dispatch class that shares the '_listen' method
# of the Event class, this is the straight monkeypatch.
- if hasattr(cls, 'dispatch'):
+ if hasattr(cls, "dispatch"):
dispatch_base = cls.dispatch.__class__
else:
dispatch_base = _Dispatch
event_names = [k for k in dict_ if _is_event_name(k)]
- dispatch_cls = type("%sDispatch" % classname,
- (dispatch_base, ), {'__slots__': event_names})
+ dispatch_cls = type(
+ "%sDispatch" % classname, (dispatch_base,), {"__slots__": event_names}
+ )
dispatch_cls._event_names = event_names
@@ -186,7 +191,7 @@ def _create_dispatcher_class(cls, classname, bases, dict_):
setattr(dispatch_inst, ls.name, ls)
dispatch_cls._event_names.append(ls.name)
- if getattr(cls, '_dispatch_target', None):
+ if getattr(cls, "_dispatch_target", None):
cls._dispatch_target.dispatch = dispatcher(cls)
@@ -221,12 +226,14 @@ class Events(util.with_metaclass(_EventMeta, object)):
# Mapper, ClassManager, Session override this to
# also accept classes, scoped_sessions, sessionmakers, etc.
- if hasattr(target, 'dispatch'):
+ if hasattr(target, "dispatch"):
if (
dispatch_is(cls.dispatch.__class__)
or dispatch_is(type, cls.dispatch.__class__)
- or (dispatch_is(_JoinedDispatcher)
- and dispatch_parent_is(cls.dispatch.__class__))
+ or (
+ dispatch_is(_JoinedDispatcher)
+ and dispatch_parent_is(cls.dispatch.__class__)
+ )
):
return target
@@ -246,7 +253,7 @@ class Events(util.with_metaclass(_EventMeta, object)):
class _JoinedDispatcher(object):
"""Represent a connection between two _Dispatch objects."""
- __slots__ = 'local', 'parent', '_instance_cls'
+ __slots__ = "local", "parent", "_instance_cls"
def __init__(self, local, parent):
self.local = local
@@ -281,5 +288,5 @@ class dispatcher(object):
def __get__(self, obj, cls):
if obj is None:
return self.dispatch
- obj.__dict__['dispatch'] = disp = self.dispatch._for_instance(obj)
+ obj.__dict__["dispatch"] = disp = self.dispatch._for_instance(obj)
return disp
diff --git a/lib/sqlalchemy/event/legacy.py b/lib/sqlalchemy/event/legacy.py
index 1883070f4..c30b922fd 100644
--- a/lib/sqlalchemy/event/legacy.py
+++ b/lib/sqlalchemy/event/legacy.py
@@ -15,10 +15,11 @@ from .. import util
def _legacy_signature(since, argnames, converter=None):
def leg(fn):
- if not hasattr(fn, '_legacy_signatures'):
+ if not hasattr(fn, "_legacy_signatures"):
fn._legacy_signatures = []
fn._legacy_signatures.append((since, argnames, converter))
return fn
+
return leg
@@ -30,15 +31,18 @@ def _wrap_fn_for_legacy(dispatch_collection, fn, argspec):
else:
has_kw = False
- if len(argnames) == len(argspec.args) \
- and has_kw is bool(argspec.keywords):
+ if len(argnames) == len(argspec.args) and has_kw is bool(
+ argspec.keywords
+ ):
if conv:
assert not has_kw
def wrap_leg(*args):
return fn(*conv(*args))
+
else:
+
def wrap_leg(*args, **kw):
argdict = dict(zip(dispatch_collection.arg_names, args))
args = [argdict[name] for name in argnames]
@@ -46,16 +50,14 @@ def _wrap_fn_for_legacy(dispatch_collection, fn, argspec):
return fn(*args, **kw)
else:
return fn(*args)
+
return wrap_leg
else:
return fn
def _indent(text, indent):
- return "\n".join(
- indent + line
- for line in text.split("\n")
- )
+ return "\n".join(indent + line for line in text.split("\n"))
def _standard_listen_example(dispatch_collection, sample_target, fn):
@@ -64,10 +66,13 @@ def _standard_listen_example(dispatch_collection, sample_target, fn):
"%(arg)s = kw['%(arg)s']" % {"arg": arg}
for arg in dispatch_collection.arg_names[0:2]
),
- " ")
+ " ",
+ )
if dispatch_collection.legacy_signatures:
- current_since = max(since for since, args, conv
- in dispatch_collection.legacy_signatures)
+ current_since = max(
+ since
+ for since, args, conv in dispatch_collection.legacy_signatures
+ )
else:
current_since = None
text = (
@@ -82,7 +87,6 @@ def _standard_listen_example(dispatch_collection, sample_target, fn):
if len(dispatch_collection.arg_names) > 3:
text += (
-
"\n# named argument style (new in 0.9)\n"
"@event.listens_for("
"%(sample_target)s, '%(event_name)s', named=True)\n"
@@ -93,13 +97,14 @@ def _standard_listen_example(dispatch_collection, sample_target, fn):
)
text %= {
- "current_since": " (arguments as of %s)" %
- current_since if current_since else "",
+ "current_since": " (arguments as of %s)" % current_since
+ if current_since
+ else "",
"event_name": fn.__name__,
"has_kw_arguments": ", **kw" if dispatch_collection.has_kw else "",
"named_event_arguments": ", ".join(dispatch_collection.arg_names),
"example_kw_arg": example_kw_arg,
- "sample_target": sample_target
+ "sample_target": sample_target,
}
return text
@@ -113,13 +118,15 @@ def _legacy_listen_examples(dispatch_collection, sample_target, fn):
"def receive_%(event_name)s("
"%(named_event_arguments)s%(has_kw_arguments)s):\n"
" \"listen for the '%(event_name)s' event\"\n"
- "\n # ... (event handling logic) ...\n" % {
+ "\n # ... (event handling logic) ...\n"
+ % {
"since": since,
"event_name": fn.__name__,
"has_kw_arguments": " **kw"
- if dispatch_collection.has_kw else "",
+ if dispatch_collection.has_kw
+ else "",
"named_event_arguments": ", ".join(args),
- "sample_target": sample_target
+ "sample_target": sample_target,
}
)
return text
@@ -133,37 +140,34 @@ def _version_signature_changes(dispatch_collection):
" arguments ``%(named_event_arguments)s%(has_kw_arguments)s``.\n"
" Listener functions which accept the previous argument \n"
" signature(s) listed above will be automatically \n"
- " adapted to the new signature." % {
+ " adapted to the new signature."
+ % {
"since": since,
"event_name": dispatch_collection.name,
"named_event_arguments": ", ".join(dispatch_collection.arg_names),
- "has_kw_arguments": ", **kw" if dispatch_collection.has_kw else ""
+ "has_kw_arguments": ", **kw" if dispatch_collection.has_kw else "",
}
)
def _augment_fn_docs(dispatch_collection, parent_dispatch_cls, fn):
- header = ".. container:: event_signatures\n\n"\
- " Example argument forms::\n"\
+ header = (
+ ".. container:: event_signatures\n\n"
+ " Example argument forms::\n"
"\n"
+ )
sample_target = getattr(parent_dispatch_cls, "_target_class_doc", "obj")
- text = (
- header +
- _indent(
- _standard_listen_example(
- dispatch_collection, sample_target, fn),
- " " * 8)
+ text = header + _indent(
+ _standard_listen_example(dispatch_collection, sample_target, fn),
+ " " * 8,
)
if dispatch_collection.legacy_signatures:
text += _indent(
- _legacy_listen_examples(
- dispatch_collection, sample_target, fn),
- " " * 8)
+ _legacy_listen_examples(dispatch_collection, sample_target, fn),
+ " " * 8,
+ )
text += _version_signature_changes(dispatch_collection)
- return util.inject_docstring_text(fn.__doc__,
- text,
- 1
- )
+ return util.inject_docstring_text(fn.__doc__, text, 1)
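Schematically, _wrap_fn_for_legacy above rearranges modern positional event arguments for listeners written against an older signature. A toy standalone version (not the real internals; argument names invented):

    import inspect

    def adapt_legacy(fn, full_argnames, legacy_argnames):
        # if fn was written against the legacy argument list,
        # reorder the modern positional args to fit it
        if inspect.getfullargspec(fn).args != list(legacy_argnames):
            return fn  # fn already accepts the modern signature

        def wrap_leg(*args):
            argdict = dict(zip(full_argnames, args))
            return fn(*[argdict[name] for name in legacy_argnames])

        return wrap_leg

    def old_listener(conn, statement):  # hypothetical pre-1.0 signature
        print(statement)

    adapted = adapt_legacy(
        old_listener, ["conn", "statement", "context"], ["conn", "statement"]
    )
    adapted("CONN", "SELECT 1", "CTX")  # prints: SELECT 1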
diff --git a/lib/sqlalchemy/event/registry.py b/lib/sqlalchemy/event/registry.py
index 8d4bada0b..c862ae403 100644
--- a/lib/sqlalchemy/event/registry.py
+++ b/lib/sqlalchemy/event/registry.py
@@ -141,11 +141,15 @@ class _EventKey(object):
"""
__slots__ = (
- 'target', 'identifier', 'fn', 'fn_key', 'fn_wrap', 'dispatch_target'
+ "target",
+ "identifier",
+ "fn",
+ "fn_key",
+ "fn_wrap",
+ "dispatch_target",
)
- def __init__(self, target, identifier,
- fn, dispatch_target, _fn_wrap=None):
+ def __init__(self, target, identifier, fn, dispatch_target, _fn_wrap=None):
self.target = target
self.identifier = identifier
self.fn = fn
@@ -169,7 +173,7 @@ class _EventKey(object):
self.identifier,
self.fn,
self.dispatch_target,
- _fn_wrap=fn_wrap
+ _fn_wrap=fn_wrap,
)
def with_dispatch_target(self, dispatch_target):
@@ -181,15 +185,18 @@ class _EventKey(object):
self.identifier,
self.fn,
dispatch_target,
- _fn_wrap=self.fn_wrap
+ _fn_wrap=self.fn_wrap,
)
def listen(self, *args, **kw):
once = kw.pop("once", False)
named = kw.pop("named", False)
- target, identifier, fn = \
- self.dispatch_target, self.identifier, self._listen_fn
+ target, identifier, fn = (
+ self.dispatch_target,
+ self.identifier,
+ self._listen_fn,
+ )
dispatch_collection = getattr(target.dispatch, identifier)
@@ -198,8 +205,9 @@ class _EventKey(object):
self = self.with_wrapper(adjusted_fn)
if once:
- self.with_wrapper(
- util.only_once(self._listen_fn)).listen(*args, **kw)
+ self.with_wrapper(util.only_once(self._listen_fn)).listen(
+ *args, **kw
+ )
else:
self.dispatch_target.dispatch._listen(self, *args, **kw)
@@ -208,8 +216,8 @@ class _EventKey(object):
if key not in _key_to_collection:
raise exc.InvalidRequestError(
- "No listeners found for event %s / %r / %s " %
- (self.target, self.identifier, self.fn)
+ "No listeners found for event %s / %r / %s "
+ % (self.target, self.identifier, self.fn)
)
dispatch_reg = _key_to_collection.pop(key)
@@ -224,20 +232,26 @@ class _EventKey(object):
"""
return self._key in _key_to_collection
- def base_listen(self, propagate=False, insert=False,
- named=False, retval=None):
+ def base_listen(
+ self, propagate=False, insert=False, named=False, retval=None
+ ):
- target, identifier, fn = \
- self.dispatch_target, self.identifier, self._listen_fn
+ target, identifier, fn = (
+ self.dispatch_target,
+ self.identifier,
+ self._listen_fn,
+ )
dispatch_collection = getattr(target.dispatch, identifier)
if insert:
- dispatch_collection.\
- for_modify(target.dispatch).insert(self, propagate)
+ dispatch_collection.for_modify(target.dispatch).insert(
+ self, propagate
+ )
else:
- dispatch_collection.\
- for_modify(target.dispatch).append(self, propagate)
+ dispatch_collection.for_modify(target.dispatch).append(
+ self, propagate
+ )
@property
def _listen_fn(self):
diff --git a/lib/sqlalchemy/events.py b/lib/sqlalchemy/events.py
index 3e97ea896..fa62b7705 100644
--- a/lib/sqlalchemy/events.py
+++ b/lib/sqlalchemy/events.py
@@ -600,39 +600,53 @@ class ConnectionEvents(event.Events):
@classmethod
def _listen(cls, event_key, retval=False):
- target, identifier, fn = \
- event_key.dispatch_target, event_key.identifier, \
- event_key._listen_fn
+ target, identifier, fn = (
+ event_key.dispatch_target,
+ event_key.identifier,
+ event_key._listen_fn,
+ )
target._has_events = True
if not retval:
- if identifier == 'before_execute':
+ if identifier == "before_execute":
orig_fn = fn
- def wrap_before_execute(conn, clauseelement,
- multiparams, params):
+ def wrap_before_execute(
+ conn, clauseelement, multiparams, params
+ ):
orig_fn(conn, clauseelement, multiparams, params)
return clauseelement, multiparams, params
+
fn = wrap_before_execute
- elif identifier == 'before_cursor_execute':
+ elif identifier == "before_cursor_execute":
orig_fn = fn
- def wrap_before_cursor_execute(conn, cursor, statement,
- parameters, context,
- executemany):
- orig_fn(conn, cursor, statement,
- parameters, context, executemany)
+ def wrap_before_cursor_execute(
+ conn, cursor, statement, parameters, context, executemany
+ ):
+ orig_fn(
+ conn,
+ cursor,
+ statement,
+ parameters,
+ context,
+ executemany,
+ )
return statement, parameters
+
fn = wrap_before_cursor_execute
- elif retval and \
- identifier not in ('before_execute',
- 'before_cursor_execute', 'handle_error'):
+ elif retval and identifier not in (
+ "before_execute",
+ "before_cursor_execute",
+ "handle_error",
+ ):
raise exc.ArgumentError(
"Only the 'before_execute', "
"'before_cursor_execute' and 'handle_error' engine "
"event listeners accept the 'retval=True' "
- "argument.")
+ "argument."
+ )
event_key.with_wrapper(fn).base_listen()
def before_execute(self, conn, clauseelement, multiparams, params):
@@ -677,8 +691,9 @@ class ConnectionEvents(event.Events):
"""
- def before_cursor_execute(self, conn, cursor, statement,
- parameters, context, executemany):
+ def before_cursor_execute(
+ self, conn, cursor, statement, parameters, context, executemany
+ ):
"""Intercept low-level cursor execute() events before execution,
receiving the string SQL statement and DBAPI-specific parameter list to
be invoked against a cursor.
@@ -718,8 +733,9 @@ class ConnectionEvents(event.Events):
"""
- def after_cursor_execute(self, conn, cursor, statement,
- parameters, context, executemany):
+ def after_cursor_execute(
+ self, conn, cursor, statement, parameters, context, executemany
+ ):
"""Intercept low-level cursor execute() events after execution.
:param conn: :class:`.Connection` object
@@ -737,8 +753,9 @@ class ConnectionEvents(event.Events):
"""
- def dbapi_error(self, conn, cursor, statement, parameters,
- context, exception):
+ def dbapi_error(
+ self, conn, cursor, statement, parameters, context, exception
+ ):
"""Intercept a raw DBAPI error.
This event is called with the DBAPI exception instance
@@ -1039,6 +1056,7 @@ class ConnectionEvents(event.Events):
.. versionadded:: 1.0.5
"""
+
def begin(self, conn):
"""Intercept begin() events.
@@ -1173,8 +1191,11 @@ class DialectEvents(event.Events):
@classmethod
def _listen(cls, event_key, retval=False):
- target, identifier, fn = \
- event_key.dispatch_target, event_key.identifier, event_key.fn
+ target, identifier, fn = (
+ event_key.dispatch_target,
+ event_key.identifier,
+ event_key.fn,
+ )
target._has_events = True
event_key.base_listen()
@@ -1235,8 +1256,9 @@ class DialectEvents(event.Events):
"""
- def do_setinputsizes(self,
- inputsizes, cursor, statement, parameters, context):
+ def do_setinputsizes(
+ self, inputsizes, cursor, statement, parameters, context
+ ):
"""Receive the setinputsizes dictionary for possible modification.
This event is emitted in the case where the dialect makes use of the
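The wrappers installed by ConnectionEvents._listen above exist so that listeners registered without retval=True need not return anything; with retval=True, a before_cursor_execute listener may rewrite the statement and parameters. A sketch (the SQL comment prefix is illustrative):

    from sqlalchemy import create_engine, event

    engine = create_engine("sqlite://")

    @event.listens_for(engine, "before_cursor_execute", retval=True)
    def comment_sql(conn, cursor, statement, parameters, context,
                    executemany):
        # with retval=True the returned pair replaces the originals
        return "/* traced */ " + statement, parameters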
diff --git a/lib/sqlalchemy/exc.py b/lib/sqlalchemy/exc.py
index 40dcb7c55..832c5ee52 100644
--- a/lib/sqlalchemy/exc.py
+++ b/lib/sqlalchemy/exc.py
@@ -22,7 +22,7 @@ class SQLAlchemyError(Exception):
code = None
def __init__(self, *arg, **kw):
- code = kw.pop('code', None)
+ code = kw.pop("code", None)
if code is not None:
self.code = code
super(SQLAlchemyError, self).__init__(*arg, **kw)
@@ -33,7 +33,7 @@ class SQLAlchemyError(Exception):
else:
return (
"(Background on this error at: "
- "http://sqlalche.me/e/%s)" % (self.code, )
+ "http://sqlalche.me/e/%s)" % (self.code,)
)
def _message(self):
@@ -48,9 +48,7 @@ class SQLAlchemyError(Exception):
message = self._message()
if self.code:
- message = (
- "%s %s" % (message, self._code_str())
- )
+ message = "%s %s" % (message, self._code_str())
return message
@@ -112,6 +110,7 @@ class CircularDependencyError(SQLAlchemyError):
see :ref:`use_alter`.
"""
+
def __init__(self, message, cycles, edges, msg=None, code=None):
if msg is None:
message += " (%s)" % ", ".join(repr(s) for s in cycles)
@@ -122,8 +121,7 @@ class CircularDependencyError(SQLAlchemyError):
self.edges = edges
def __reduce__(self):
- return self.__class__, (None, self.cycles,
- self.edges, self.args[0])
+ return self.__class__, (None, self.cycles, self.edges, self.args[0])
class CompileError(SQLAlchemyError):
@@ -140,8 +138,9 @@ class UnsupportedCompilationError(CompileError):
def __init__(self, compiler, element_type):
super(UnsupportedCompilationError, self).__init__(
- "Compiler %r can't render element of type %s" %
- (compiler, element_type))
+ "Compiler %r can't render element of type %s"
+ % (compiler, element_type)
+ )
class IdentifierError(SQLAlchemyError):
@@ -158,6 +157,7 @@ class DisconnectionError(SQLAlchemyError):
regarding the connection attempt.
"""
+
invalidate_pool = False
@@ -175,6 +175,7 @@ class InvalidatePoolError(DisconnectionError):
.. versionadded:: 1.2
"""
+
invalidate_pool = True
@@ -213,6 +214,7 @@ class NoReferencedTableError(NoReferenceError):
located.
"""
+
def __init__(self, message, tname):
NoReferenceError.__init__(self, message)
self.table_name = tname
@@ -226,14 +228,17 @@ class NoReferencedColumnError(NoReferenceError):
located.
"""
+
def __init__(self, message, tname, cname):
NoReferenceError.__init__(self, message)
self.table_name = tname
self.column_name = cname
def __reduce__(self):
- return self.__class__, (self.args[0], self.table_name,
- self.column_name)
+ return (
+ self.__class__,
+ (self.args[0], self.table_name, self.column_name),
+ )
class NoSuchTableError(InvalidRequestError):
@@ -273,6 +278,7 @@ class DontWrapMixin(object):
"""
+
# Moved to orm.exc; compatibility definition installed by orm import until 0.6
UnmappedColumnError = None
@@ -310,8 +316,10 @@ class StatementError(SQLAlchemyError):
self.detail.append(msg)
def __reduce__(self):
- return self.__class__, (self.args[0], self.statement,
- self.params, self.orig)
+ return (
+ self.__class__,
+ (self.args[0], self.statement, self.params, self.orig),
+ )
def __str__(self):
from sqlalchemy.sql import util
@@ -325,9 +333,7 @@ class StatementError(SQLAlchemyError):
code_str = self._code_str()
if code_str:
details.append(code_str)
- return ' '.join([
- "(%s)" % det for det in self.detail
- ] + details)
+ return " ".join(["(%s)" % det for det in self.detail] + details)
class DBAPIError(StatementError):
@@ -353,18 +359,23 @@ class DBAPIError(StatementError):
"""
- code = 'dbapi'
+ code = "dbapi"
@classmethod
- def instance(cls, statement, params,
- orig, dbapi_base_err,
- connection_invalidated=False,
- dialect=None):
+ def instance(
+ cls,
+ statement,
+ params,
+ orig,
+ dbapi_base_err,
+ connection_invalidated=False,
+ dialect=None,
+ ):
# Don't ever wrap these, just return them directly as if
# DBAPIError didn't exist.
- if (isinstance(orig, BaseException) and
- not isinstance(orig, Exception)) or \
- isinstance(orig, DontWrapMixin):
+ if (
+ isinstance(orig, BaseException) and not isinstance(orig, Exception)
+ ) or isinstance(orig, DontWrapMixin):
return orig
if orig is not None:
@@ -372,17 +383,28 @@ class DBAPIError(StatementError):
# raise a StatementError
if isinstance(orig, SQLAlchemyError) and statement:
return StatementError(
- "(%s.%s) %s" %
- (orig.__class__.__module__, orig.__class__.__name__,
- orig.args[0]),
- statement, params, orig, code=orig.code
+ "(%s.%s) %s"
+ % (
+ orig.__class__.__module__,
+ orig.__class__.__name__,
+ orig.args[0],
+ ),
+ statement,
+ params,
+ orig,
+ code=orig.code,
)
elif not isinstance(orig, dbapi_base_err) and statement:
return StatementError(
- "(%s.%s) %s" %
- (orig.__class__.__module__, orig.__class__.__name__,
- orig),
- statement, params, orig
+ "(%s.%s) %s"
+ % (
+ orig.__class__.__module__,
+ orig.__class__.__name__,
+ orig,
+ ),
+ statement,
+ params,
+ orig,
)
glob = globals()
@@ -390,31 +412,42 @@ class DBAPIError(StatementError):
name = super_.__name__
if dialect:
name = dialect.dbapi_exception_translation_map.get(
- name, name)
+ name, name
+ )
if name in glob and issubclass(glob[name], DBAPIError):
cls = glob[name]
break
- return cls(statement, params, orig, connection_invalidated,
- code=cls.code)
+ return cls(
+ statement, params, orig, connection_invalidated, code=cls.code
+ )
def __reduce__(self):
- return self.__class__, (self.statement, self.params,
- self.orig, self.connection_invalidated)
+ return (
+ self.__class__,
+ (
+ self.statement,
+ self.params,
+ self.orig,
+ self.connection_invalidated,
+ ),
+ )
- def __init__(self, statement, params, orig, connection_invalidated=False,
- code=None):
+ def __init__(
+ self, statement, params, orig, connection_invalidated=False, code=None
+ ):
try:
text = str(orig)
except Exception as e:
- text = 'Error in str() of DB-API-generated exception: ' + str(e)
+ text = "Error in str() of DB-API-generated exception: " + str(e)
StatementError.__init__(
self,
- '(%s.%s) %s' % (
- orig.__class__.__module__, orig.__class__.__name__, text, ),
+ "(%s.%s) %s"
+ % (orig.__class__.__module__, orig.__class__.__name__, text),
statement,
params,
- orig, code=code
+ orig,
+ code=code,
)
self.connection_invalidated = connection_invalidated
@@ -466,8 +499,10 @@ class NotSupportedError(DatabaseError):
code = "tw8g"
+
# Warnings
+
class SADeprecationWarning(DeprecationWarning):
"""Issued once per usage of a deprecated API."""
diff --git a/lib/sqlalchemy/ext/__init__.py b/lib/sqlalchemy/ext/__init__.py
index 9558b2a1f..9fed09e2b 100644
--- a/lib/sqlalchemy/ext/__init__.py
+++ b/lib/sqlalchemy/ext/__init__.py
@@ -8,4 +8,3 @@
from .. import util as _sa_util
_sa_util.dependencies.resolve_all("sqlalchemy.ext")
-
diff --git a/lib/sqlalchemy/ext/associationproxy.py b/lib/sqlalchemy/ext/associationproxy.py
index ff9433d4d..56b91ce0b 100644
--- a/lib/sqlalchemy/ext/associationproxy.py
+++ b/lib/sqlalchemy/ext/associationproxy.py
@@ -76,7 +76,7 @@ def association_proxy(target_collection, attr, **kw):
return AssociationProxy(target_collection, attr, **kw)
-ASSOCIATION_PROXY = util.symbol('ASSOCIATION_PROXY')
+ASSOCIATION_PROXY = util.symbol("ASSOCIATION_PROXY")
"""Symbol indicating an :class:`InspectionAttr` that's
of type :class:`.AssociationProxy`.
@@ -92,10 +92,17 @@ class AssociationProxy(interfaces.InspectionAttrInfo):
is_attribute = False
extension_type = ASSOCIATION_PROXY
- def __init__(self, target_collection, attr, creator=None,
- getset_factory=None, proxy_factory=None,
- proxy_bulk_set=None, info=None,
- cascade_scalar_deletes=False):
+ def __init__(
+ self,
+ target_collection,
+ attr,
+ creator=None,
+ getset_factory=None,
+ proxy_factory=None,
+ proxy_bulk_set=None,
+ info=None,
+ cascade_scalar_deletes=False,
+ ):
"""Construct a new :class:`.AssociationProxy`.
The :func:`.association_proxy` function is provided as the usual
@@ -162,8 +169,11 @@ class AssociationProxy(interfaces.InspectionAttrInfo):
self.proxy_bulk_set = proxy_bulk_set
self.cascade_scalar_deletes = cascade_scalar_deletes
- self.key = '_%s_%s_%s' % (
- type(self).__name__, target_collection, id(self))
+ self.key = "_%s_%s_%s" % (
+ type(self).__name__,
+ target_collection,
+ id(self),
+ )
if info:
self.info = info
@@ -264,12 +274,17 @@ class AssociationProxy(interfaces.InspectionAttrInfo):
def getter(target):
return _getter(target) if target is not None else None
+
if collection_class is dict:
+
def setter(o, k, v):
setattr(o, attr, v)
+
else:
+
def setter(o, v):
setattr(o, attr, v)
+
return getter, setter
@@ -325,20 +340,21 @@ class AssociationProxyInstance(object):
def for_proxy(cls, parent, owning_class, parent_instance):
target_collection = parent.target_collection
value_attr = parent.value_attr
- prop = orm.class_mapper(owning_class).\
- get_property(target_collection)
+ prop = orm.class_mapper(owning_class).get_property(target_collection)
# this was never asserted before but this should be made clear.
if not isinstance(prop, orm.RelationshipProperty):
raise NotImplementedError(
"association proxy to a non-relationship "
- "intermediary is not supported")
+ "intermediary is not supported"
+ )
target_class = prop.mapper.class_
try:
target_assoc = cls._cls_unwrap_target_assoc_proxy(
- target_class, value_attr)
+ target_class, value_attr
+ )
except AttributeError:
# the proxied attribute doesn't exist on the target class;
# return an "ambiguous" instance that will work on a per-object
@@ -353,8 +369,8 @@ class AssociationProxyInstance(object):
@classmethod
def _construct_for_assoc(
- cls, target_assoc, parent, owning_class,
- target_class, value_attr):
+ cls, target_assoc, parent, owning_class, target_class, value_attr
+ ):
if target_assoc is not None:
return ObjectAssociationProxyInstance(
parent, owning_class, target_class, value_attr
@@ -371,8 +387,9 @@ class AssociationProxyInstance(object):
)
def _get_property(self):
- return orm.class_mapper(self.owning_class).\
- get_property(self.target_collection)
+ return orm.class_mapper(self.owning_class).get_property(
+ self.target_collection
+ )
@property
def _comparator(self):
@@ -388,7 +405,8 @@ class AssociationProxyInstance(object):
@util.memoized_property
def _unwrap_target_assoc_proxy(self):
return self._cls_unwrap_target_assoc_proxy(
- self.target_class, self.value_attr)
+ self.target_class, self.value_attr
+ )
@property
def remote_attr(self):
@@ -448,8 +466,11 @@ class AssociationProxyInstance(object):
@util.memoized_property
def _value_is_scalar(self):
- return not self._get_property().\
- mapper.get_property(self.value_attr).uselist
+ return (
+ not self._get_property()
+ .mapper.get_property(self.value_attr)
+ .uselist
+ )
@property
def _target_is_object(self):
@@ -468,12 +489,17 @@ class AssociationProxyInstance(object):
def getter(target):
return _getter(target) if target is not None else None
+
if collection_class is dict:
+
def setter(o, k, v):
return setattr(o, attr, v)
+
else:
+
def setter(o, v):
return setattr(o, attr, v)
+
return getter, setter
@property
@@ -500,14 +526,18 @@ class AssociationProxyInstance(object):
return proxy
self.collection_class, proxy = self._new(
- _lazy_collection(obj, self.target_collection))
+ _lazy_collection(obj, self.target_collection)
+ )
setattr(obj, self.key, (id(obj), id(self), proxy))
return proxy
def set(self, obj, values):
if self.scalar:
- creator = self.parent.creator \
- if self.parent.creator else self.target_class
+ creator = (
+ self.parent.creator
+ if self.parent.creator
+ else self.target_class
+ )
target = getattr(obj, self.target_collection)
if target is None:
if values is None:
@@ -535,35 +565,52 @@ class AssociationProxyInstance(object):
delattr(obj, self.target_collection)
def _new(self, lazy_collection):
- creator = self.parent.creator if self.parent.creator else \
- self.target_class
+ creator = (
+ self.parent.creator if self.parent.creator else self.target_class
+ )
collection_class = util.duck_type_collection(lazy_collection())
if self.parent.proxy_factory:
- return collection_class, self.parent.proxy_factory(
- lazy_collection, creator, self.value_attr, self)
+ return (
+ collection_class,
+ self.parent.proxy_factory(
+ lazy_collection, creator, self.value_attr, self
+ ),
+ )
if self.parent.getset_factory:
- getter, setter = self.parent.getset_factory(
- collection_class, self)
+ getter, setter = self.parent.getset_factory(collection_class, self)
else:
getter, setter = self.parent._default_getset(collection_class)
if collection_class is list:
- return collection_class, _AssociationList(
- lazy_collection, creator, getter, setter, self)
+ return (
+ collection_class,
+ _AssociationList(
+ lazy_collection, creator, getter, setter, self
+ ),
+ )
elif collection_class is dict:
- return collection_class, _AssociationDict(
- lazy_collection, creator, getter, setter, self)
+ return (
+ collection_class,
+ _AssociationDict(
+ lazy_collection, creator, getter, setter, self
+ ),
+ )
elif collection_class is set:
- return collection_class, _AssociationSet(
- lazy_collection, creator, getter, setter, self)
+ return (
+ collection_class,
+ _AssociationSet(
+ lazy_collection, creator, getter, setter, self
+ ),
+ )
else:
raise exc.ArgumentError(
- 'could not guess which interface to use for '
+ "could not guess which interface to use for "
'collection_class "%s" backing "%s"; specify a '
- 'proxy_factory and proxy_bulk_set manually' %
- (self.collection_class.__name__, self.target_collection))
+ "proxy_factory and proxy_bulk_set manually"
+ % (self.collection_class.__name__, self.target_collection)
+ )
def _set(self, proxy, values):
if self.parent.proxy_bulk_set:
@@ -576,16 +623,19 @@ class AssociationProxyInstance(object):
proxy.update(values)
else:
raise exc.ArgumentError(
- 'no proxy_bulk_set supplied for custom '
- 'collection_class implementation')
+ "no proxy_bulk_set supplied for custom "
+ "collection_class implementation"
+ )
def _inflate(self, proxy):
- creator = self.parent.creator and \
- self.parent.creator or self.target_class
+ creator = (
+ self.parent.creator and self.parent.creator or self.target_class
+ )
if self.parent.getset_factory:
getter, setter = self.parent.getset_factory(
- self.collection_class, self)
+ self.collection_class, self
+ )
else:
getter, setter = self.parent._default_getset(self.collection_class)
@@ -594,12 +644,13 @@ class AssociationProxyInstance(object):
proxy.setter = setter
def _criterion_exists(self, criterion=None, **kwargs):
- is_has = kwargs.pop('is_has', None)
+ is_has = kwargs.pop("is_has", None)
target_assoc = self._unwrap_target_assoc_proxy
if target_assoc is not None:
inner = target_assoc._criterion_exists(
- criterion=criterion, **kwargs)
+ criterion=criterion, **kwargs
+ )
return self._comparator._criterion_exists(inner)
if self._target_is_object:
@@ -631,15 +682,15 @@ class AssociationProxyInstance(object):
"""
if self._unwrap_target_assoc_proxy is None and (
- self.scalar and (
- not self._target_is_object or self._value_is_scalar)
+ self.scalar
+ and (not self._target_is_object or self._value_is_scalar)
):
raise exc.InvalidRequestError(
- "'any()' not implemented for scalar "
- "attributes. Use has()."
+ "'any()' not implemented for scalar " "attributes. Use has()."
)
return self._criterion_exists(
- criterion=criterion, is_has=False, **kwargs)
+ criterion=criterion, is_has=False, **kwargs
+ )
def has(self, criterion=None, **kwargs):
"""Produce a proxied 'has' expression using EXISTS.
@@ -651,14 +702,15 @@ class AssociationProxyInstance(object):
"""
if self._unwrap_target_assoc_proxy is None and (
- not self.scalar or (
- self._target_is_object and not self._value_is_scalar)
+ not self.scalar
+ or (self._target_is_object and not self._value_is_scalar)
):
raise exc.InvalidRequestError(
- "'has()' not implemented for collections. "
- "Use any().")
+ "'has()' not implemented for collections. " "Use any()."
+ )
return self._criterion_exists(
- criterion=criterion, is_has=True, **kwargs)
+ criterion=criterion, is_has=True, **kwargs
+ )
class AmbiguousAssociationProxyInstance(AssociationProxyInstance):
@@ -673,10 +725,14 @@ class AmbiguousAssociationProxyInstance(AssociationProxyInstance):
"Association proxy %s.%s refers to an attribute '%s' that is not "
"directly mapped on class %s; therefore this operation cannot "
"proceed since we don't know what type of object is referred "
- "towards" % (
- self.owning_class.__name__, self.target_collection,
- self.value_attr, self.target_class
- ))
+ "towards"
+ % (
+ self.owning_class.__name__,
+ self.target_collection,
+ self.value_attr,
+ self.target_class,
+ )
+ )
def get(self, obj):
self._ambiguous()
@@ -718,27 +774,32 @@ class AmbiguousAssociationProxyInstance(AssociationProxyInstance):
return self
def _populate_cache(self, instance_class):
- prop = orm.class_mapper(self.owning_class).\
- get_property(self.target_collection)
+ prop = orm.class_mapper(self.owning_class).get_property(
+ self.target_collection
+ )
if inspect(instance_class).mapper.isa(prop.mapper):
target_class = instance_class
try:
target_assoc = self._cls_unwrap_target_assoc_proxy(
- target_class, self.value_attr)
+ target_class, self.value_attr
+ )
except AttributeError:
pass
else:
- self._lookup_cache[instance_class] = \
- self._construct_for_assoc(
- target_assoc, self.parent, self.owning_class,
- target_class, self.value_attr
+ self._lookup_cache[instance_class] = self._construct_for_assoc(
+ target_assoc,
+ self.parent,
+ self.owning_class,
+ target_class,
+ self.value_attr,
)
class ObjectAssociationProxyInstance(AssociationProxyInstance):
"""an :class:`.AssociationProxyInstance` that has an object as a target.
"""
+
_target_is_object = True
_is_canonical = True
@@ -756,17 +817,21 @@ class ObjectAssociationProxyInstance(AssociationProxyInstance):
if target_assoc is not None:
return self._comparator._criterion_exists(
target_assoc.contains(obj)
- if not target_assoc.scalar else target_assoc == obj
+ if not target_assoc.scalar
+ else target_assoc == obj
)
- elif self._target_is_object and self.scalar and \
- not self._value_is_scalar:
+ elif (
+ self._target_is_object
+ and self.scalar
+ and not self._value_is_scalar
+ ):
return self._comparator.has(
getattr(self.target_class, self.value_attr).contains(obj)
)
- elif self._target_is_object and self.scalar and \
- self._value_is_scalar:
+ elif self._target_is_object and self.scalar and self._value_is_scalar:
raise exc.InvalidRequestError(
- "contains() doesn't apply to a scalar object endpoint; use ==")
+ "contains() doesn't apply to a scalar object endpoint; use =="
+ )
else:
return self._comparator._criterion_exists(**{self.value_attr: obj})
@@ -777,7 +842,7 @@ class ObjectAssociationProxyInstance(AssociationProxyInstance):
if obj is None:
return or_(
self._comparator.has(**{self.value_attr: obj}),
- self._comparator == None
+ self._comparator == None,
)
else:
return self._comparator.has(**{self.value_attr: obj})
@@ -786,14 +851,17 @@ class ObjectAssociationProxyInstance(AssociationProxyInstance):
# note the has() here will fail for collections; eq_()
# is only allowed with a scalar.
return self._comparator.has(
- getattr(self.target_class, self.value_attr) != obj)
+ getattr(self.target_class, self.value_attr) != obj
+ )
class ColumnAssociationProxyInstance(
- ColumnOperators, AssociationProxyInstance):
+ ColumnOperators, AssociationProxyInstance
+):
"""an :class:`.AssociationProxyInstance` that has a database column as a
target.
"""
+
_target_is_object = False
_is_canonical = True
@@ -803,9 +871,7 @@ class ColumnAssociationProxyInstance(
self.remote_attr.operate(operator.eq, other)
)
if other is None:
- return or_(
- expr, self._comparator == None
- )
+ return or_(expr, self._comparator == None)
else:
return expr
@@ -824,11 +890,11 @@ class _lazy_collection(object):
return getattr(self.parent, self.target)
def __getstate__(self):
- return {'obj': self.parent, 'target': self.target}
+ return {"obj": self.parent, "target": self.target}
def __setstate__(self, state):
- self.parent = state['obj']
- self.target = state['target']
+ self.parent = state["obj"]
+ self.target = state["target"]
class _AssociationCollection(object):
@@ -874,11 +940,11 @@ class _AssociationCollection(object):
__nonzero__ = __bool__
def __getstate__(self):
- return {'parent': self.parent, 'lazy_collection': self.lazy_collection}
+ return {"parent": self.parent, "lazy_collection": self.lazy_collection}
def __setstate__(self, state):
- self.parent = state['parent']
- self.lazy_collection = state['lazy_collection']
+ self.parent = state["parent"]
+ self.lazy_collection = state["lazy_collection"]
self.parent._inflate(self)
@@ -925,8 +991,8 @@ class _AssociationList(_AssociationCollection):
if len(value) != len(rng):
raise ValueError(
"attempt to assign sequence of size %s to "
- "extended slice of size %s" % (len(value),
- len(rng)))
+ "extended slice of size %s" % (len(value), len(rng))
+ )
for i, item in zip(rng, value):
self._set(self.col[i], item)
@@ -968,8 +1034,14 @@ class _AssociationList(_AssociationCollection):
col.append(item)
def count(self, value):
- return sum([1 for _ in
- util.itertools_filter(lambda v: v == value, iter(self))])
+ return sum(
+ [
+ 1
+ for _ in util.itertools_filter(
+ lambda v: v == value, iter(self)
+ )
+ ]
+ )
def extend(self, values):
for v in values:
@@ -999,7 +1071,7 @@ class _AssociationList(_AssociationCollection):
raise NotImplementedError
def clear(self):
- del self.col[0:len(self.col)]
+ del self.col[0 : len(self.col)]
def __eq__(self, other):
return list(self) == other
@@ -1040,6 +1112,7 @@ class _AssociationList(_AssociationCollection):
if not isinstance(n, int):
return NotImplemented
return list(self) * n
+
__rmul__ = __mul__
def __iadd__(self, iterable):
@@ -1072,13 +1145,17 @@ class _AssociationList(_AssociationCollection):
raise TypeError("%s objects are unhashable" % type(self).__name__)
for func_name, func in list(locals().items()):
- if (util.callable(func) and func.__name__ == func_name and
- not func.__doc__ and hasattr(list, func_name)):
+ if (
+ util.callable(func)
+ and func.__name__ == func_name
+ and not func.__doc__
+ and hasattr(list, func_name)
+ ):
func.__doc__ = getattr(list, func_name).__doc__
del func_name, func
-_NotProvided = util.symbol('_NotProvided')
+_NotProvided = util.symbol("_NotProvided")
class _AssociationDict(_AssociationCollection):
@@ -1160,6 +1237,7 @@ class _AssociationDict(_AssociationCollection):
return self.col.keys()
if util.py2k:
+
def iteritems(self):
return ((key, self._get(self.col[key])) for key in self.col)
@@ -1174,7 +1252,9 @@ class _AssociationDict(_AssociationCollection):
def items(self):
return [(k, self._get(self.col[k])) for k in self]
+
else:
+
def items(self):
return ((key, self._get(self.col[key])) for key in self.col)
@@ -1194,14 +1274,15 @@ class _AssociationDict(_AssociationCollection):
def update(self, *a, **kw):
if len(a) > 1:
- raise TypeError('update expected at most 1 arguments, got %i' %
- len(a))
+ raise TypeError(
+ "update expected at most 1 arguments, got %i" % len(a)
+ )
elif len(a) == 1:
seq_or_map = a[0]
# discern dict from sequence - took the advice from
# http://www.voidspace.org.uk/python/articles/duck_typing.shtml
# still not perfect :(
- if hasattr(seq_or_map, 'keys'):
+ if hasattr(seq_or_map, "keys"):
for item in seq_or_map:
self[item] = seq_or_map[item]
else:
@@ -1211,7 +1292,8 @@ class _AssociationDict(_AssociationCollection):
except ValueError:
raise ValueError(
"dictionary update sequence "
- "requires 2-element tuples")
+ "requires 2-element tuples"
+ )
for key, value in kw.items():
self[key] = value
@@ -1223,8 +1305,12 @@ class _AssociationDict(_AssociationCollection):
raise TypeError("%s objects are unhashable" % type(self).__name__)
for func_name, func in list(locals().items()):
- if (util.callable(func) and func.__name__ == func_name and
- not func.__doc__ and hasattr(dict, func_name)):
+ if (
+ util.callable(func)
+ and func.__name__ == func_name
+ and not func.__doc__
+ and hasattr(dict, func_name)
+ ):
func.__doc__ = getattr(dict, func_name).__doc__
del func_name, func
@@ -1288,7 +1374,7 @@ class _AssociationSet(_AssociationCollection):
def pop(self):
if not self.col:
- raise KeyError('pop from an empty set')
+ raise KeyError("pop from an empty set")
member = self.col.pop()
return self._get(member)
@@ -1420,7 +1506,11 @@ class _AssociationSet(_AssociationCollection):
raise TypeError("%s objects are unhashable" % type(self).__name__)
for func_name, func in list(locals().items()):
- if (util.callable(func) and func.__name__ == func_name and
- not func.__doc__ and hasattr(set, func_name)):
+ if (
+ util.callable(func)
+ and func.__name__ == func_name
+ and not func.__doc__
+ and hasattr(set, func_name)
+ ):
func.__doc__ = getattr(set, func_name).__doc__
del func_name, func
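
A quick illustration of the association proxy API touched above, since every hunk in this file is a formatting-only change from the reformat. This is a minimal sketch, not code from the commit: the User/Keyword models and the "cheese-inspector" value are invented for demonstration.

from sqlalchemy import Column, ForeignKey, Integer, String, create_engine
from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session, relationship

Base = declarative_base()

class User(Base):
    __tablename__ = "user"
    id = Column(Integer, primary_key=True)
    name = Column(String(64))
    kw = relationship("Keyword")
    # proxy the string 'keyword' column across the 'kw' relationship;
    # the creator turns appended strings into Keyword rows
    keywords = association_proxy(
        "kw", "keyword", creator=lambda k: Keyword(keyword=k)
    )

class Keyword(Base):
    __tablename__ = "keyword"
    id = Column(Integer, primary_key=True)
    user_id = Column(Integer, ForeignKey("user.id"))
    keyword = Column(String(64))

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
session = Session(engine)

user = User(name="jek")
user.keywords.append("cheese-inspector")  # creates a Keyword behind the scenes
session.add(user)
session.commit()

# the ColumnAssociationProxyInstance.__eq__ reformatted above renders this
# comparison as an EXISTS subquery against the keyword table
print(session.query(User).filter(User.keywords == "cheese-inspector").all())
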
diff --git a/lib/sqlalchemy/ext/automap.py b/lib/sqlalchemy/ext/automap.py
index cafb3d61c..747373a2a 100644
--- a/lib/sqlalchemy/ext/automap.py
+++ b/lib/sqlalchemy/ext/automap.py
@@ -580,7 +580,8 @@ def name_for_scalar_relationship(base, local_cls, referred_cls, constraint):
def name_for_collection_relationship(
- base, local_cls, referred_cls, constraint):
+ base, local_cls, referred_cls, constraint
+):
"""Return the attribute name that should be used to refer from one
class to another, for a collection reference.
@@ -607,7 +608,8 @@ def name_for_collection_relationship(
def generate_relationship(
- base, direction, return_fn, attrname, local_cls, referred_cls, **kw):
+ base, direction, return_fn, attrname, local_cls, referred_cls, **kw
+):
r"""Generate a :func:`.relationship` or :func:`.backref` on behalf of two
mapped classes.
@@ -677,6 +679,7 @@ class AutomapBase(object):
:ref:`automap_toplevel`
"""
+
__abstract__ = True
classes = None
@@ -694,15 +697,16 @@ class AutomapBase(object):
@classmethod
def prepare(
- cls,
- engine=None,
- reflect=False,
- schema=None,
- classname_for_table=classname_for_table,
- collection_class=list,
- name_for_scalar_relationship=name_for_scalar_relationship,
- name_for_collection_relationship=name_for_collection_relationship,
- generate_relationship=generate_relationship):
+ cls,
+ engine=None,
+ reflect=False,
+ schema=None,
+ classname_for_table=classname_for_table,
+ collection_class=list,
+ name_for_scalar_relationship=name_for_scalar_relationship,
+ name_for_collection_relationship=name_for_collection_relationship,
+ generate_relationship=generate_relationship,
+ ):
"""Extract mapped classes and relationships from the :class:`.MetaData` and
perform mappings.
@@ -752,15 +756,16 @@ class AutomapBase(object):
engine,
schema=schema,
extend_existing=True,
- autoload_replace=False
+ autoload_replace=False,
)
_CONFIGURE_MUTEX.acquire()
try:
table_to_map_config = dict(
(m.local_table, m)
- for m in _DeferredMapperConfig.
- classes_for_base(cls, sort=False)
+ for m in _DeferredMapperConfig.classes_for_base(
+ cls, sort=False
+ )
)
many_to_many = []
@@ -774,30 +779,39 @@ class AutomapBase(object):
elif table not in table_to_map_config:
mapped_cls = type(
classname_for_table(cls, table.name, table),
- (cls, ),
- {"__table__": table}
+ (cls,),
+ {"__table__": table},
)
map_config = _DeferredMapperConfig.config_for_cls(
- mapped_cls)
+ mapped_cls
+ )
cls.classes[map_config.cls.__name__] = mapped_cls
table_to_map_config[table] = map_config
for map_config in table_to_map_config.values():
- _relationships_for_fks(cls,
- map_config,
- table_to_map_config,
- collection_class,
- name_for_scalar_relationship,
- name_for_collection_relationship,
- generate_relationship)
+ _relationships_for_fks(
+ cls,
+ map_config,
+ table_to_map_config,
+ collection_class,
+ name_for_scalar_relationship,
+ name_for_collection_relationship,
+ generate_relationship,
+ )
for lcl_m2m, rem_m2m, m2m_const, table in many_to_many:
- _m2m_relationship(cls, lcl_m2m, rem_m2m, m2m_const, table,
- table_to_map_config,
- collection_class,
- name_for_scalar_relationship,
- name_for_collection_relationship,
- generate_relationship)
+ _m2m_relationship(
+ cls,
+ lcl_m2m,
+ rem_m2m,
+ m2m_const,
+ table,
+ table_to_map_config,
+ collection_class,
+ name_for_scalar_relationship,
+ name_for_collection_relationship,
+ generate_relationship,
+ )
for map_config in _DeferredMapperConfig.classes_for_base(cls):
map_config.map()
@@ -853,20 +867,27 @@ def automap_base(declarative_base=None, **kw):
return type(
Base.__name__,
- (AutomapBase, Base,),
- {"__abstract__": True, "classes": util.Properties({})}
+ (AutomapBase, Base),
+ {"__abstract__": True, "classes": util.Properties({})},
)
def _is_many_to_many(automap_base, table):
- fk_constraints = [const for const in table.constraints
- if isinstance(const, ForeignKeyConstraint)]
+ fk_constraints = [
+ const
+ for const in table.constraints
+ if isinstance(const, ForeignKeyConstraint)
+ ]
if len(fk_constraints) != 2:
return None, None, None
cols = sum(
- [[fk.parent for fk in fk_constraint.elements]
- for fk_constraint in fk_constraints], [])
+ [
+ [fk.parent for fk in fk_constraint.elements]
+ for fk_constraint in fk_constraints
+ ],
+ [],
+ )
if set(cols) != set(table.c):
return None, None, None
@@ -874,15 +895,19 @@ def _is_many_to_many(automap_base, table):
return (
fk_constraints[0].elements[0].column.table,
fk_constraints[1].elements[0].column.table,
- fk_constraints
+ fk_constraints,
)
-def _relationships_for_fks(automap_base, map_config, table_to_map_config,
- collection_class,
- name_for_scalar_relationship,
- name_for_collection_relationship,
- generate_relationship):
+def _relationships_for_fks(
+ automap_base,
+ map_config,
+ table_to_map_config,
+ collection_class,
+ name_for_scalar_relationship,
+ name_for_collection_relationship,
+ generate_relationship,
+):
local_table = map_config.local_table
local_cls = map_config.cls # derived from a weakref, may be None
@@ -898,32 +923,33 @@ def _relationships_for_fks(automap_base, map_config, table_to_map_config,
referred_cls = referred_cfg.cls
if local_cls is not referred_cls and issubclass(
- local_cls, referred_cls):
+ local_cls, referred_cls
+ ):
continue
relationship_name = name_for_scalar_relationship(
- automap_base,
- local_cls,
- referred_cls, constraint)
+ automap_base, local_cls, referred_cls, constraint
+ )
backref_name = name_for_collection_relationship(
- automap_base,
- referred_cls,
- local_cls,
- constraint
+ automap_base, referred_cls, local_cls, constraint
)
o2m_kws = {}
nullable = False not in {fk.parent.nullable for fk in fks}
if not nullable:
- o2m_kws['cascade'] = "all, delete-orphan"
+ o2m_kws["cascade"] = "all, delete-orphan"
- if constraint.ondelete and \
- constraint.ondelete.lower() == "cascade":
- o2m_kws['passive_deletes'] = True
+ if (
+ constraint.ondelete
+ and constraint.ondelete.lower() == "cascade"
+ ):
+ o2m_kws["passive_deletes"] = True
else:
- if constraint.ondelete and \
- constraint.ondelete.lower() == "set null":
- o2m_kws['passive_deletes'] = True
+ if (
+ constraint.ondelete
+ and constraint.ondelete.lower() == "set null"
+ ):
+ o2m_kws["passive_deletes"] = True
create_backref = backref_name not in referred_cfg.properties
@@ -931,54 +957,65 @@ def _relationships_for_fks(automap_base, map_config, table_to_map_config,
if create_backref:
backref_obj = generate_relationship(
automap_base,
- interfaces.ONETOMANY, backref,
- backref_name, referred_cls, local_cls,
+ interfaces.ONETOMANY,
+ backref,
+ backref_name,
+ referred_cls,
+ local_cls,
collection_class=collection_class,
- **o2m_kws)
+ **o2m_kws
+ )
else:
backref_obj = None
- rel = generate_relationship(automap_base,
- interfaces.MANYTOONE,
- relationship,
- relationship_name,
- local_cls, referred_cls,
- foreign_keys=[
- fk.parent
- for fk in constraint.elements],
- backref=backref_obj,
- remote_side=[
- fk.column
- for fk in constraint.elements]
- )
+ rel = generate_relationship(
+ automap_base,
+ interfaces.MANYTOONE,
+ relationship,
+ relationship_name,
+ local_cls,
+ referred_cls,
+ foreign_keys=[fk.parent for fk in constraint.elements],
+ backref=backref_obj,
+ remote_side=[fk.column for fk in constraint.elements],
+ )
if rel is not None:
map_config.properties[relationship_name] = rel
if not create_backref:
referred_cfg.properties[
- backref_name].back_populates = relationship_name
+ backref_name
+ ].back_populates = relationship_name
elif create_backref:
- rel = generate_relationship(automap_base,
- interfaces.ONETOMANY,
- relationship,
- backref_name,
- referred_cls, local_cls,
- foreign_keys=[
- fk.parent
- for fk in constraint.elements],
- back_populates=relationship_name,
- collection_class=collection_class,
- **o2m_kws)
+ rel = generate_relationship(
+ automap_base,
+ interfaces.ONETOMANY,
+ relationship,
+ backref_name,
+ referred_cls,
+ local_cls,
+ foreign_keys=[fk.parent for fk in constraint.elements],
+ back_populates=relationship_name,
+ collection_class=collection_class,
+ **o2m_kws
+ )
if rel is not None:
referred_cfg.properties[backref_name] = rel
map_config.properties[
- relationship_name].back_populates = backref_name
-
-
-def _m2m_relationship(automap_base, lcl_m2m, rem_m2m, m2m_const, table,
- table_to_map_config,
- collection_class,
- name_for_scalar_relationship,
- name_for_collection_relationship,
- generate_relationship):
+ relationship_name
+ ].back_populates = backref_name
+
+
+def _m2m_relationship(
+ automap_base,
+ lcl_m2m,
+ rem_m2m,
+ m2m_const,
+ table,
+ table_to_map_config,
+ collection_class,
+ name_for_scalar_relationship,
+ name_for_collection_relationship,
+ generate_relationship,
+):
map_config = table_to_map_config.get(lcl_m2m, None)
referred_cfg = table_to_map_config.get(rem_m2m, None)
@@ -989,14 +1026,10 @@ def _m2m_relationship(automap_base, lcl_m2m, rem_m2m, m2m_const, table,
referred_cls = referred_cfg.cls
relationship_name = name_for_collection_relationship(
- automap_base,
- local_cls,
- referred_cls, m2m_const[0])
+ automap_base, local_cls, referred_cls, m2m_const[0]
+ )
backref_name = name_for_collection_relationship(
- automap_base,
- referred_cls,
- local_cls,
- m2m_const[1]
+ automap_base, referred_cls, local_cls, m2m_const[1]
)
create_backref = backref_name not in referred_cfg.properties
@@ -1008,48 +1041,56 @@ def _m2m_relationship(automap_base, lcl_m2m, rem_m2m, m2m_const, table,
interfaces.MANYTOMANY,
backref,
backref_name,
- referred_cls, local_cls,
- collection_class=collection_class
+ referred_cls,
+ local_cls,
+ collection_class=collection_class,
)
else:
backref_obj = None
- rel = generate_relationship(automap_base,
- interfaces.MANYTOMANY,
- relationship,
- relationship_name,
- local_cls, referred_cls,
- secondary=table,
- primaryjoin=and_(
- fk.column == fk.parent
- for fk in m2m_const[0].elements),
- secondaryjoin=and_(
- fk.column == fk.parent
- for fk in m2m_const[1].elements),
- backref=backref_obj,
- collection_class=collection_class
- )
+ rel = generate_relationship(
+ automap_base,
+ interfaces.MANYTOMANY,
+ relationship,
+ relationship_name,
+ local_cls,
+ referred_cls,
+ secondary=table,
+ primaryjoin=and_(
+ fk.column == fk.parent for fk in m2m_const[0].elements
+ ),
+ secondaryjoin=and_(
+ fk.column == fk.parent for fk in m2m_const[1].elements
+ ),
+ backref=backref_obj,
+ collection_class=collection_class,
+ )
if rel is not None:
map_config.properties[relationship_name] = rel
if not create_backref:
referred_cfg.properties[
- backref_name].back_populates = relationship_name
+ backref_name
+ ].back_populates = relationship_name
elif create_backref:
- rel = generate_relationship(automap_base,
- interfaces.MANYTOMANY,
- relationship,
- backref_name,
- referred_cls, local_cls,
- secondary=table,
- primaryjoin=and_(
- fk.column == fk.parent
- for fk in m2m_const[1].elements),
- secondaryjoin=and_(
- fk.column == fk.parent
- for fk in m2m_const[0].elements),
- back_populates=relationship_name,
- collection_class=collection_class)
+ rel = generate_relationship(
+ automap_base,
+ interfaces.MANYTOMANY,
+ relationship,
+ backref_name,
+ referred_cls,
+ local_cls,
+ secondary=table,
+ primaryjoin=and_(
+ fk.column == fk.parent for fk in m2m_const[1].elements
+ ),
+ secondaryjoin=and_(
+ fk.column == fk.parent for fk in m2m_const[0].elements
+ ),
+ back_populates=relationship_name,
+ collection_class=collection_class,
+ )
if rel is not None:
referred_cfg.properties[backref_name] = rel
map_config.properties[
- relationship_name].back_populates = backref_name
+ relationship_name
+ ].back_populates = backref_name
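
For context on the automap hunks above (again purely mechanical re-wrapping), a sketch of the public entry point whose signature was reformatted. It assumes a pre-existing SQLite file with related user and address tables; the file and table names are illustrative.

from sqlalchemy import create_engine
from sqlalchemy.ext.automap import automap_base
from sqlalchemy.orm import Session

engine = create_engine("sqlite:///mydatabase.db")

Base = automap_base()
# reflect the schema, generate mapped classes, and wire up relationships
# via the _relationships_for_fks() / _m2m_relationship() paths shown above
Base.prepare(engine, reflect=True)

User = Base.classes.user        # generated classes are keyed by table name
Address = Base.classes.address

session = Session(engine)
u1 = session.query(User).first()
# one-to-many collections such as .address_collection are named by
# name_for_collection_relationship()
print(u1.address_collection)
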
diff --git a/lib/sqlalchemy/ext/baked.py b/lib/sqlalchemy/ext/baked.py
index 516879142..f55231a09 100644
--- a/lib/sqlalchemy/ext/baked.py
+++ b/lib/sqlalchemy/ext/baked.py
@@ -38,7 +38,8 @@ class Bakery(object):
"""
- __slots__ = 'cls', 'cache'
+
+ __slots__ = "cls", "cache"
def __init__(self, cls_, cache):
self.cls = cls_
@@ -51,7 +52,7 @@ class Bakery(object):
class BakedQuery(object):
"""A builder object for :class:`.query.Query` objects."""
- __slots__ = 'steps', '_bakery', '_cache_key', '_spoiled'
+ __slots__ = "steps", "_bakery", "_cache_key", "_spoiled"
def __init__(self, bakery, initial_fn, args=()):
self._cache_key = ()
@@ -148,7 +149,7 @@ class BakedQuery(object):
"""
if not full and not self._spoiled:
_spoil_point = self._clone()
- _spoil_point._cache_key += ('_query_only', )
+ _spoil_point._cache_key += ("_query_only",)
self.steps = [_spoil_point._retrieve_baked_query]
self._spoiled = True
return self
@@ -164,7 +165,7 @@ class BakedQuery(object):
session will want to use.
"""
- return self._cache_key + (session._query_cls, )
+ return self._cache_key + (session._query_cls,)
def _with_lazyload_options(self, options, effective_path, cache_path=None):
"""Cloning version of _add_lazyload_options.
@@ -201,16 +202,20 @@ class BakedQuery(object):
key += cache_key
self.add_criteria(
- lambda q: q._with_current_path(effective_path).
- _conditional_options(*options),
- cache_path.path, key
+ lambda q: q._with_current_path(
+ effective_path
+ )._conditional_options(*options),
+ cache_path.path,
+ key,
)
def _retrieve_baked_query(self, session):
query = self._bakery.get(self._effective_key(session), None)
if query is None:
query = self._as_query(session)
- self._bakery[self._effective_key(session)] = query.with_session(None)
+ self._bakery[self._effective_key(session)] = query.with_session(
+ None
+ )
return query.with_session(session)
def _bake(self, session):
@@ -227,8 +232,12 @@ class BakedQuery(object):
# so delete some compilation-use-only attributes that can take up
# space
for attr in (
- '_correlate', '_from_obj', '_mapper_adapter_map',
- '_joinpath', '_joinpoint'):
+ "_correlate",
+ "_from_obj",
+ "_mapper_adapter_map",
+ "_joinpath",
+ "_joinpoint",
+ ):
query.__dict__.pop(attr, None)
self._bakery[self._effective_key(session)] = context
return context
@@ -276,11 +285,13 @@ class BakedQuery(object):
session = query_or_session.session
if session is None:
raise sa_exc.ArgumentError(
- "Given Query needs to be associated with a Session")
+ "Given Query needs to be associated with a Session"
+ )
else:
raise TypeError(
- "Query or Session object expected, got %r." %
- type(query_or_session))
+ "Query or Session object expected, got %r."
+ % type(query_or_session)
+ )
return self._as_query(session)
def _as_query(self, session):
@@ -299,10 +310,10 @@ class BakedQuery(object):
a "baked" query so that we save on performance too.
"""
- context.attributes['baked_queries'] = baked_queries = []
+ context.attributes["baked_queries"] = baked_queries = []
for k, v in list(context.attributes.items()):
if isinstance(v, Query):
- if 'subquery' in k:
+ if "subquery" in k:
bk = BakedQuery(self._bakery, lambda *args: v)
bk._cache_key = self._cache_key + k
bk._bake(session)
@@ -310,15 +321,17 @@ class BakedQuery(object):
del context.attributes[k]
def _unbake_subquery_loaders(
- self, session, context, params, post_criteria):
+ self, session, context, params, post_criteria
+ ):
"""Retrieve subquery eager loaders stored by _bake_subquery_loaders
and turn them back into Result objects that will iterate just
like a Query object.
"""
for k, cache_key, query in context.attributes["baked_queries"]:
- bk = BakedQuery(self._bakery,
- lambda sess, q=query: q.with_session(sess))
+ bk = BakedQuery(
+ self._bakery, lambda sess, q=query: q.with_session(sess)
+ )
bk._cache_key = cache_key
q = bk.for_session(session)
for fn in post_criteria:
@@ -334,7 +347,8 @@ class Result(object):
against a target :class:`.Session`, and is then invoked for results.
"""
- __slots__ = 'bq', 'session', '_params', '_post_criteria'
+
+ __slots__ = "bq", "session", "_params", "_post_criteria"
def __init__(self, bq, session):
self.bq = bq
@@ -350,7 +364,8 @@ class Result(object):
elif len(args) > 0:
raise sa_exc.ArgumentError(
"params() takes zero or one positional argument, "
- "which is a dictionary.")
+ "which is a dictionary."
+ )
self._params.update(kw)
return self
@@ -403,7 +418,8 @@ class Result(object):
context.attributes = context.attributes.copy()
bq._unbake_subquery_loaders(
- self.session, context, self._params, self._post_criteria)
+ self.session, context, self._params, self._post_criteria
+ )
context.statement.use_labels = True
if context.autoflush and not context.populate_existing:
@@ -426,7 +442,7 @@ class Result(object):
"""
- col = func.count(literal_column('*'))
+ col = func.count(literal_column("*"))
bq = self.bq.with_criteria(lambda q: q.from_self(col))
return bq.for_session(self.session).params(self._params).scalar()
@@ -456,8 +472,10 @@ class Result(object):
"""
bq = self.bq.with_criteria(lambda q: q.slice(0, 1))
ret = list(
- bq.for_session(self.session).params(self._params).
- _using_post_criteria(self._post_criteria))
+ bq.for_session(self.session)
+ .params(self._params)
+ ._using_post_criteria(self._post_criteria)
+ )
if len(ret) > 0:
return ret[0]
else:
@@ -473,7 +491,8 @@ class Result(object):
ret = self.one_or_none()
except orm_exc.MultipleResultsFound:
raise orm_exc.MultipleResultsFound(
- "Multiple rows were found for one()")
+ "Multiple rows were found for one()"
+ )
else:
if ret is None:
raise orm_exc.NoResultFound("No row was found for one()")
@@ -497,7 +516,8 @@ class Result(object):
return None
else:
raise orm_exc.MultipleResultsFound(
- "Multiple rows were found for one_or_none()")
+ "Multiple rows were found for one_or_none()"
+ )
def all(self):
"""Return all rows.
@@ -533,13 +553,18 @@ class Result(object):
# None present in ident - turn those comparisons
# into "IS NULL"
if None in primary_key_identity:
- nones = set([
- _get_params[col].key for col, value in
- zip(mapper.primary_key, primary_key_identity)
- if value is None
- ])
+ nones = set(
+ [
+ _get_params[col].key
+ for col, value in zip(
+ mapper.primary_key, primary_key_identity
+ )
+ if value is None
+ ]
+ )
_lcl_get_clause = sql_util.adapt_criterion_to_null(
- _lcl_get_clause, nones)
+ _lcl_get_clause, nones
+ )
_lcl_get_clause = q._adapt_clause(_lcl_get_clause, True, False)
q._criterion = _lcl_get_clause
@@ -556,16 +581,20 @@ class Result(object):
# key so that if a race causes multiple calls to _get_clause,
# we've cached on ours
bq = bq._clone()
- bq._cache_key += (_get_clause, )
+ bq._cache_key += (_get_clause,)
bq = bq.with_criteria(
- setup, tuple(elem is None for elem in primary_key_identity))
+ setup, tuple(elem is None for elem in primary_key_identity)
+ )
- params = dict([
- (_get_params[primary_key].key, id_val)
- for id_val, primary_key
- in zip(primary_key_identity, mapper.primary_key)
- ])
+ params = dict(
+ [
+ (_get_params[primary_key].key, id_val)
+ for id_val, primary_key in zip(
+ primary_key_identity, mapper.primary_key
+ )
+ ]
+ )
result = list(bq.for_session(self.session).params(**params))
l = len(result)
@@ -578,7 +607,8 @@ class Result(object):
@util.deprecated(
- "1.2", "Baked lazy loading is now the default implementation.")
+ "1.2", "Baked lazy loading is now the default implementation."
+)
def bake_lazy_loaders():
"""Enable the use of baked queries for all lazyloaders systemwide.
@@ -590,7 +620,8 @@ def bake_lazy_loaders():
@util.deprecated(
- "1.2", "Baked lazy loading is now the default implementation.")
+ "1.2", "Baked lazy loading is now the default implementation."
+)
def unbake_lazy_loaders():
"""Disable the use of baked queries for all lazyloaders systemwide.
@@ -601,7 +632,8 @@ def unbake_lazy_loaders():
"""
raise NotImplementedError(
- "Baked lazy loading is now the default implementation")
+ "Baked lazy loading is now the default implementation"
+ )
@strategy_options.loader_option()
@@ -615,20 +647,27 @@ def baked_lazyload(loadopt, attr):
@baked_lazyload._add_unbound_fn
@util.deprecated(
- "1.2", "Baked lazy loading is now the default "
- "implementation for lazy loading.")
+ "1.2",
+ "Baked lazy loading is now the default "
+ "implementation for lazy loading.",
+)
def baked_lazyload(*keys):
return strategy_options._UnboundLoad._from_keys(
- strategy_options._UnboundLoad.baked_lazyload, keys, False, {})
+ strategy_options._UnboundLoad.baked_lazyload, keys, False, {}
+ )
@baked_lazyload._add_unbound_all_fn
@util.deprecated(
- "1.2", "Baked lazy loading is now the default "
- "implementation for lazy loading.")
+ "1.2",
+ "Baked lazy loading is now the default "
+ "implementation for lazy loading.",
+)
def baked_lazyload_all(*keys):
return strategy_options._UnboundLoad._from_keys(
- strategy_options._UnboundLoad.baked_lazyload, keys, True, {})
+ strategy_options._UnboundLoad.baked_lazyload, keys, True, {}
+ )
+
baked_lazyload = baked_lazyload._unbound_fn
baked_lazyload_all = baked_lazyload_all._unbound_all_fn
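
The baked.py changes are likewise formatting-only. As a reminder of the API, a minimal sketch of bakery usage; the User model and sample data are invented for the example.

from sqlalchemy import Column, Integer, String, bindparam, create_engine
from sqlalchemy.ext import baked
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session

Base = declarative_base()

class User(Base):
    __tablename__ = "user"
    id = Column(Integer, primary_key=True)
    name = Column(String(64))

bakery = baked.bakery()  # a Bakery, per the __slots__ hunk above

def search_for_user(session, username):
    # each lambda runs once; the built Query is cached in the bakery,
    # keyed on the lambdas plus the session's query class
    baked_query = bakery(lambda session: session.query(User))
    baked_query += lambda q: q.filter(User.name == bindparam("username"))
    # params() and all() live on the Result class reformatted above
    return baked_query(session).params(username=username).all()

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
session = Session(engine)
session.add(User(name="alice"))
session.commit()
print(search_for_user(session, "alice"))
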
diff --git a/lib/sqlalchemy/ext/compiler.py b/lib/sqlalchemy/ext/compiler.py
index 6a0909d36..220b2c057 100644
--- a/lib/sqlalchemy/ext/compiler.py
+++ b/lib/sqlalchemy/ext/compiler.py
@@ -407,37 +407,44 @@ def compiles(class_, *specs):
def decorate(fn):
# get an existing @compiles handler
- existing = class_.__dict__.get('_compiler_dispatcher', None)
+ existing = class_.__dict__.get("_compiler_dispatcher", None)
# get the original handler. All ClauseElement classes have one
# of these, but some TypeEngine classes will not.
- existing_dispatch = getattr(class_, '_compiler_dispatch', None)
+ existing_dispatch = getattr(class_, "_compiler_dispatch", None)
if not existing:
existing = _dispatcher()
if existing_dispatch:
+
def _wrap_existing_dispatch(element, compiler, **kw):
try:
return existing_dispatch(element, compiler, **kw)
except exc.UnsupportedCompilationError:
raise exc.CompileError(
"%s construct has no default "
- "compilation handler." % type(element))
- existing.specs['default'] = _wrap_existing_dispatch
+ "compilation handler." % type(element)
+ )
+
+ existing.specs["default"] = _wrap_existing_dispatch
# TODO: why is the lambda needed ?
- setattr(class_, '_compiler_dispatch',
- lambda *arg, **kw: existing(*arg, **kw))
- setattr(class_, '_compiler_dispatcher', existing)
+ setattr(
+ class_,
+ "_compiler_dispatch",
+ lambda *arg, **kw: existing(*arg, **kw),
+ )
+ setattr(class_, "_compiler_dispatcher", existing)
if specs:
for s in specs:
existing.specs[s] = fn
else:
- existing.specs['default'] = fn
+ existing.specs["default"] = fn
return fn
+
return decorate
@@ -445,7 +452,7 @@ def deregister(class_):
"""Remove all custom compilers associated with a given
:class:`.ClauseElement` type."""
- if hasattr(class_, '_compiler_dispatcher'):
+ if hasattr(class_, "_compiler_dispatcher"):
# regenerate default _compiler_dispatch
visitors._generate_dispatch(class_)
# remove custom directive
@@ -461,10 +468,11 @@ class _dispatcher(object):
fn = self.specs.get(compiler.dialect.name, None)
if not fn:
try:
- fn = self.specs['default']
+ fn = self.specs["default"]
except KeyError:
raise exc.CompileError(
"%s construct has no default "
- "compilation handler." % type(element))
+ "compilation handler." % type(element)
+ )
return fn(element, compiler, **kw)
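
The compiler.py hunks reformat the @compiles dispatcher without changing dispatch behavior: per-dialect handlers still land in _dispatcher.specs keyed by dialect name, with "default" as the fallback. A small sketch of that mechanism (MyColumn is a made-up construct):

from sqlalchemy import select
from sqlalchemy.dialects import postgresql
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.expression import ColumnClause

class MyColumn(ColumnClause):
    pass

@compiles(MyColumn)  # stored under specs["default"]
def compile_mycolumn(element, compiler, **kw):
    return "[%s]" % element.name

@compiles(MyColumn, "postgresql")  # stored under specs["postgresql"]
def compile_mycolumn_pg(element, compiler, **kw):
    return '"%s"' % element.name

stmt = select([MyColumn("x")])
print(stmt)                                          # SELECT [x]
print(stmt.compile(dialect=postgresql.dialect()))    # SELECT "x"
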
diff --git a/lib/sqlalchemy/ext/declarative/__init__.py b/lib/sqlalchemy/ext/declarative/__init__.py
index cb81f51e5..2b0a37884 100644
--- a/lib/sqlalchemy/ext/declarative/__init__.py
+++ b/lib/sqlalchemy/ext/declarative/__init__.py
@@ -5,14 +5,31 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-from .api import declarative_base, synonym_for, comparable_using, \
- instrument_declarative, ConcreteBase, AbstractConcreteBase, \
- DeclarativeMeta, DeferredReflection, has_inherited_table,\
- declared_attr, as_declarative
+from .api import (
+ declarative_base,
+ synonym_for,
+ comparable_using,
+ instrument_declarative,
+ ConcreteBase,
+ AbstractConcreteBase,
+ DeclarativeMeta,
+ DeferredReflection,
+ has_inherited_table,
+ declared_attr,
+ as_declarative,
+)
-__all__ = ['declarative_base', 'synonym_for', 'has_inherited_table',
- 'comparable_using', 'instrument_declarative', 'declared_attr',
- 'as_declarative',
- 'ConcreteBase', 'AbstractConcreteBase', 'DeclarativeMeta',
- 'DeferredReflection']
+__all__ = [
+ "declarative_base",
+ "synonym_for",
+ "has_inherited_table",
+ "comparable_using",
+ "instrument_declarative",
+ "declared_attr",
+ "as_declarative",
+ "ConcreteBase",
+ "AbstractConcreteBase",
+ "DeclarativeMeta",
+ "DeferredReflection",
+]
diff --git a/lib/sqlalchemy/ext/declarative/api.py b/lib/sqlalchemy/ext/declarative/api.py
index 865cd16f0..987e92119 100644
--- a/lib/sqlalchemy/ext/declarative/api.py
+++ b/lib/sqlalchemy/ext/declarative/api.py
@@ -8,9 +8,13 @@
from ...schema import Table, MetaData, Column
-from ...orm import synonym as _orm_synonym, \
- comparable_property,\
- interfaces, properties, attributes
+from ...orm import (
+ synonym as _orm_synonym,
+ comparable_property,
+ interfaces,
+ properties,
+ attributes,
+)
from ...orm.util import polymorphic_union
from ...orm.base import _mapper_or_none
from ...util import OrderedDict, hybridmethod, hybridproperty
@@ -19,9 +23,13 @@ from ... import exc
import weakref
import re
-from .base import _as_declarative, \
- _declarative_constructor,\
- _DeferredMapperConfig, _add_attribute, _del_attribute
+from .base import (
+ _as_declarative,
+ _declarative_constructor,
+ _DeferredMapperConfig,
+ _add_attribute,
+ _del_attribute,
+)
from .clsregistry import _class_resolver
@@ -31,10 +39,10 @@ def instrument_declarative(cls, registry, metadata):
MetaData object.
"""
- if '_decl_class_registry' in cls.__dict__:
+ if "_decl_class_registry" in cls.__dict__:
raise exc.InvalidRequestError(
- "Class %r already has been "
- "instrumented declaratively" % cls)
+ "Class %r already has been " "instrumented declaratively" % cls
+ )
cls._decl_class_registry = registry
cls.metadata = metadata
_as_declarative(cls, cls.__name__, cls.__dict__)
@@ -54,14 +62,14 @@ def has_inherited_table(cls):
"""
for class_ in cls.__mro__[1:]:
- if getattr(class_, '__table__', None) is not None:
+ if getattr(class_, "__table__", None) is not None:
return True
return False
class DeclarativeMeta(type):
def __init__(cls, classname, bases, dict_):
- if '_decl_class_registry' not in cls.__dict__:
+ if "_decl_class_registry" not in cls.__dict__:
_as_declarative(cls, classname, cls.__dict__)
type.__init__(cls, classname, bases, dict_)
@@ -71,6 +79,7 @@ class DeclarativeMeta(type):
def __delattr__(cls, key):
_del_attribute(cls, key)
+
def synonym_for(name, map_column=False):
"""Decorator that produces an :func:`.orm.synonym` attribute in conjunction
with a Python descriptor.
@@ -104,8 +113,10 @@ def synonym_for(name, map_column=False):
can be achieved with synonyms.
"""
+
def decorate(fn):
return _orm_synonym(name, map_column=map_column, descriptor=fn)
+
return decorate
@@ -127,8 +138,10 @@ def comparable_using(comparator_factory):
prop = comparable_property(MyComparatorType)
"""
+
def decorate(fn):
return comparable_property(comparator_factory, fn)
+
return decorate
@@ -190,14 +203,16 @@ class declared_attr(interfaces._MappedAttribute, property):
self._cascading = cascading
def __get__(desc, self, cls):
- reg = cls.__dict__.get('_sa_declared_attr_reg', None)
+ reg = cls.__dict__.get("_sa_declared_attr_reg", None)
if reg is None:
- if not re.match(r'^__.+__$', desc.fget.__name__) and \
- attributes.manager_of_class(cls) is None:
+ if (
+ not re.match(r"^__.+__$", desc.fget.__name__)
+ and attributes.manager_of_class(cls) is None
+ ):
util.warn(
"Unmanaged access of declarative attribute %s from "
- "non-mapped class %s" %
- (desc.fget.__name__, cls.__name__))
+ "non-mapped class %s" % (desc.fget.__name__, cls.__name__)
+ )
return desc.fget(cls)
elif desc in reg:
return reg[desc]
@@ -283,10 +298,16 @@ class _stateful_declared_attr(declared_attr):
return declared_attr(fn, **self.kw)
-def declarative_base(bind=None, metadata=None, mapper=None, cls=object,
- name='Base', constructor=_declarative_constructor,
- class_registry=None,
- metaclass=DeclarativeMeta):
+def declarative_base(
+ bind=None,
+ metadata=None,
+ mapper=None,
+ cls=object,
+ name="Base",
+ constructor=_declarative_constructor,
+ class_registry=None,
+ metaclass=DeclarativeMeta,
+):
r"""Construct a base class for declarative class definitions.
The new base class will be given a metaclass that produces
@@ -357,16 +378,17 @@ def declarative_base(bind=None, metadata=None, mapper=None, cls=object,
class_registry = weakref.WeakValueDictionary()
bases = not isinstance(cls, tuple) and (cls,) or cls
- class_dict = dict(_decl_class_registry=class_registry,
- metadata=lcl_metadata)
+ class_dict = dict(
+ _decl_class_registry=class_registry, metadata=lcl_metadata
+ )
if isinstance(cls, type):
- class_dict['__doc__'] = cls.__doc__
+ class_dict["__doc__"] = cls.__doc__
if constructor:
- class_dict['__init__'] = constructor
+ class_dict["__init__"] = constructor
if mapper:
- class_dict['__mapper_cls__'] = mapper
+ class_dict["__mapper_cls__"] = mapper
return metaclass(name, bases, class_dict)
@@ -401,9 +423,10 @@ def as_declarative(**kw):
:func:`.declarative_base`
"""
+
def decorate(cls):
- kw['cls'] = cls
- kw['name'] = cls.__name__
+ kw["cls"] = cls
+ kw["name"] = cls.__name__
return declarative_base(**kw)
return decorate
@@ -456,10 +479,13 @@ class ConcreteBase(object):
@classmethod
def _create_polymorphic_union(cls, mappers):
- return polymorphic_union(OrderedDict(
- (mp.polymorphic_identity, mp.local_table)
- for mp in mappers
- ), 'type', 'pjoin')
+ return polymorphic_union(
+ OrderedDict(
+ (mp.polymorphic_identity, mp.local_table) for mp in mappers
+ ),
+ "type",
+ "pjoin",
+ )
@classmethod
def __declare_first__(cls):
@@ -568,7 +594,7 @@ class AbstractConcreteBase(ConcreteBase):
@classmethod
def _sa_decl_prepare_nocascade(cls):
- if getattr(cls, '__mapper__', None):
+ if getattr(cls, "__mapper__", None):
return
to_map = _DeferredMapperConfig.config_for_cls(cls)
@@ -604,8 +630,9 @@ class AbstractConcreteBase(ConcreteBase):
def mapper_args():
args = m_args()
- args['polymorphic_on'] = pjoin.c.type
+ args["polymorphic_on"] = pjoin.c.type
return args
+
to_map.mapper_args_fn = mapper_args
m = to_map.map()
@@ -684,6 +711,7 @@ class DeferredReflection(object):
.. versionadded:: 0.8
"""
+
@classmethod
def prepare(cls, engine):
"""Reflect all :class:`.Table` objects for all current
@@ -696,8 +724,10 @@ class DeferredReflection(object):
mapper = thingy.cls.__mapper__
metadata = mapper.class_.metadata
for rel in mapper._props.values():
- if isinstance(rel, properties.RelationshipProperty) and \
- rel.secondary is not None:
+ if (
+ isinstance(rel, properties.RelationshipProperty)
+ and rel.secondary is not None
+ ):
if isinstance(rel.secondary, Table):
cls._reflect_table(rel.secondary, engine)
elif isinstance(rel.secondary, _class_resolver):
@@ -711,6 +741,7 @@ class DeferredReflection(object):
t1 = Table(key, metadata)
cls._reflect_table(t1, engine)
return t1
+
return _resolve
@classmethod
@@ -724,10 +755,12 @@ class DeferredReflection(object):
@classmethod
def _reflect_table(cls, table, engine):
- Table(table.name,
- table.metadata,
- extend_existing=True,
- autoload_replace=False,
- autoload=True,
- autoload_with=engine,
- schema=table.schema)
+ Table(
+ table.name,
+ table.metadata,
+ extend_existing=True,
+ autoload_replace=False,
+ autoload=True,
+ autoload_with=engine,
+ schema=table.schema,
+ )
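
Rounding out the declarative API hunks, a sketch combining declarative_base() and declared_attr, whose signatures and descriptors were re-wrapped above. The mixin and Account model are illustrative only.

from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.ext.declarative import declarative_base, declared_attr

Base = declarative_base()

class TablenameMixin(object):
    # declared_attr.__get__ (reformatted above) invokes this once per
    # mapped subclass; outside a mapping it warns about unmanaged access
    @declared_attr
    def __tablename__(cls):
        return cls.__name__.lower()

class Account(TablenameMixin, Base):
    id = Column(Integer, primary_key=True)
    name = Column(String(50))

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
print(Account.__table__.name)  # "account"
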
diff --git a/lib/sqlalchemy/ext/declarative/base.py b/lib/sqlalchemy/ext/declarative/base.py
index f27314b5e..07778f733 100644
--- a/lib/sqlalchemy/ext/declarative/base.py
+++ b/lib/sqlalchemy/ext/declarative/base.py
@@ -39,7 +39,7 @@ def _resolve_for_abstract_or_classical(cls):
if cls is object:
return None
- if _get_immediate_cls_attr(cls, '__abstract__', strict=True):
+ if _get_immediate_cls_attr(cls, "__abstract__", strict=True):
for sup in cls.__bases__:
sup = _resolve_for_abstract_or_classical(sup)
if sup is not None:
@@ -59,7 +59,7 @@ def _dive_for_classically_mapped_class(cls):
# if we are within a base hierarchy, don't
# search at all for classical mappings
- if hasattr(cls, '_decl_class_registry'):
+ if hasattr(cls, "_decl_class_registry"):
return None
manager = instrumentation.manager_of_class(cls)
@@ -89,15 +89,19 @@ def _get_immediate_cls_attr(cls, attrname, strict=False):
return None
for base in cls.__mro__:
- _is_declarative_inherits = hasattr(base, '_decl_class_registry')
- _is_classicial_inherits = not _is_declarative_inherits and \
- _dive_for_classically_mapped_class(base) is not None
+ _is_declarative_inherits = hasattr(base, "_decl_class_registry")
+ _is_classicial_inherits = (
+ not _is_declarative_inherits
+ and _dive_for_classically_mapped_class(base) is not None
+ )
if attrname in base.__dict__ and (
- base is cls or
- ((base in cls.__bases__ if strict else True)
+ base is cls
+ or (
+ (base in cls.__bases__ if strict else True)
and not _is_declarative_inherits
- and not _is_classicial_inherits)
+ and not _is_classicial_inherits
+ )
):
return getattr(base, attrname)
else:
@@ -108,9 +112,10 @@ def _as_declarative(cls, classname, dict_):
global declared_attr, declarative_props
if declared_attr is None:
from .api import declared_attr
+
declarative_props = (declared_attr, util.classproperty)
- if _get_immediate_cls_attr(cls, '__abstract__', strict=True):
+ if _get_immediate_cls_attr(cls, "__abstract__", strict=True):
return
_MapperConfig.setup_mapping(cls, classname, dict_)
@@ -119,23 +124,23 @@ def _as_declarative(cls, classname, dict_):
def _check_declared_props_nocascade(obj, name, cls):
if isinstance(obj, declarative_props):
- if getattr(obj, '_cascading', False):
+ if getattr(obj, "_cascading", False):
util.warn(
"@declared_attr.cascading is not supported on the %s "
"attribute on class %s. This attribute invokes for "
- "subclasses in any case." % (name, cls))
+ "subclasses in any case." % (name, cls)
+ )
return True
else:
return False
class _MapperConfig(object):
-
@classmethod
def setup_mapping(cls, cls_, classname, dict_):
defer_map = _get_immediate_cls_attr(
- cls_, '_sa_decl_prepare_nocascade', strict=True) or \
- hasattr(cls_, '_sa_decl_prepare')
+ cls_, "_sa_decl_prepare_nocascade", strict=True
+ ) or hasattr(cls_, "_sa_decl_prepare")
if defer_map:
cfg_cls = _DeferredMapperConfig
@@ -179,12 +184,14 @@ class _MapperConfig(object):
self.map()
def _setup_declared_events(self):
- if _get_immediate_cls_attr(self.cls, '__declare_last__'):
+ if _get_immediate_cls_attr(self.cls, "__declare_last__"):
+
@event.listens_for(mapper, "after_configured")
def after_configured():
self.cls.__declare_last__()
- if _get_immediate_cls_attr(self.cls, '__declare_first__'):
+ if _get_immediate_cls_attr(self.cls, "__declare_first__"):
+
@event.listens_for(mapper, "before_configured")
def before_configured():
self.cls.__declare_first__()
@@ -198,59 +205,62 @@ class _MapperConfig(object):
tablename = None
for base in cls.__mro__:
- class_mapped = base is not cls and \
- _declared_mapping_info(base) is not None and \
- not _get_immediate_cls_attr(
- base, '_sa_decl_prepare_nocascade', strict=True)
+ class_mapped = (
+ base is not cls
+ and _declared_mapping_info(base) is not None
+ and not _get_immediate_cls_attr(
+ base, "_sa_decl_prepare_nocascade", strict=True
+ )
+ )
if not class_mapped and base is not cls:
self._produce_column_copies(base)
for name, obj in vars(base).items():
- if name == '__mapper_args__':
- check_decl = \
- _check_declared_props_nocascade(obj, name, cls)
- if not mapper_args_fn and (
- not class_mapped or
- check_decl
- ):
+ if name == "__mapper_args__":
+ check_decl = _check_declared_props_nocascade(
+ obj, name, cls
+ )
+ if not mapper_args_fn and (not class_mapped or check_decl):
# don't even invoke __mapper_args__ until
# after we've determined everything about the
# mapped table.
# make a copy of it so a class-level dictionary
# is not overwritten when we update column-based
# arguments.
- mapper_args_fn = lambda: dict(cls.__mapper_args__) # noqa
- elif name == '__tablename__':
- check_decl = \
- _check_declared_props_nocascade(obj, name, cls)
- if not tablename and (
- not class_mapped or
- check_decl
- ):
+ mapper_args_fn = lambda: dict(
+ cls.__mapper_args__
+ ) # noqa
+ elif name == "__tablename__":
+ check_decl = _check_declared_props_nocascade(
+ obj, name, cls
+ )
+ if not tablename and (not class_mapped or check_decl):
tablename = cls.__tablename__
- elif name == '__table_args__':
- check_decl = \
- _check_declared_props_nocascade(obj, name, cls)
- if not table_args and (
- not class_mapped or
- check_decl
- ):
+ elif name == "__table_args__":
+ check_decl = _check_declared_props_nocascade(
+ obj, name, cls
+ )
+ if not table_args and (not class_mapped or check_decl):
table_args = cls.__table_args__
if not isinstance(
- table_args, (tuple, dict, type(None))):
+ table_args, (tuple, dict, type(None))
+ ):
raise exc.ArgumentError(
"__table_args__ value must be a tuple, "
- "dict, or None")
+ "dict, or None"
+ )
if base is not cls:
inherited_table_args = True
elif class_mapped:
if isinstance(obj, declarative_props):
- util.warn("Regular (i.e. not __special__) "
- "attribute '%s.%s' uses @declared_attr, "
- "but owning class %s is mapped - "
- "not applying to subclass %s."
- % (base.__name__, name, base, cls))
+ util.warn(
+ "Regular (i.e. not __special__) "
+ "attribute '%s.%s' uses @declared_attr, "
+ "but owning class %s is mapped - "
+ "not applying to subclass %s."
+ % (base.__name__, name, base, cls)
+ )
continue
elif base is not cls:
# we're a mixin, abstract base, or something that is
@@ -263,7 +273,8 @@ class _MapperConfig(object):
"Mapper properties (i.e. deferred,"
"column_property(), relationship(), etc.) must "
"be declared as @declared_attr callables "
- "on declarative mixin classes.")
+ "on declarative mixin classes."
+ )
elif isinstance(obj, declarative_props):
oldclassprop = isinstance(obj, util.classproperty)
if not oldclassprop and obj._cascading:
@@ -278,15 +289,18 @@ class _MapperConfig(object):
"Attribute '%s' on class %s cannot be "
"processed due to "
"@declared_attr.cascading; "
- "skipping" % (name, cls))
- dict_[name] = column_copies[obj] = \
- ret = obj.__get__(obj, cls)
+ "skipping" % (name, cls)
+ )
+ dict_[name] = column_copies[
+ obj
+ ] = ret = obj.__get__(obj, cls)
setattr(cls, name, ret)
else:
if oldclassprop:
util.warn_deprecated(
"Use of sqlalchemy.util.classproperty on "
- "declarative classes is deprecated.")
+ "declarative classes is deprecated."
+ )
# access attribute using normal class access
ret = getattr(cls, name)
@@ -294,14 +308,20 @@ class _MapperConfig(object):
# or similar. note there is no known case that
# produces nested proxies, so we are only
# looking one level deep right now.
- if isinstance(ret, InspectionAttr) and \
- ret._is_internal_proxy and not isinstance(
- ret.original_property, MapperProperty):
+ if (
+ isinstance(ret, InspectionAttr)
+ and ret._is_internal_proxy
+ and not isinstance(
+ ret.original_property, MapperProperty
+ )
+ ):
ret = ret.descriptor
dict_[name] = column_copies[obj] = ret
- if isinstance(ret, (Column, MapperProperty)) and \
- ret.doc is None:
+ if (
+ isinstance(ret, (Column, MapperProperty))
+ and ret.doc is None
+ ):
ret.doc = obj.__doc__
# here, the attribute is some other kind of property that
# we assume is not part of the declarative mapping.
@@ -321,8 +341,9 @@ class _MapperConfig(object):
util.warn(
"Attribute '%s' on class %s appears to be a non-schema "
"'sqlalchemy.sql.column()' "
- "object; this won't be part of the declarative mapping" %
- (key, cls))
+ "object; this won't be part of the declarative mapping"
+ % (key, cls)
+ )
def _produce_column_copies(self, base):
cls = self.cls
@@ -340,10 +361,11 @@ class _MapperConfig(object):
raise exc.InvalidRequestError(
"Columns with foreign keys to other columns "
"must be declared as @declared_attr callables "
- "on declarative mixin classes. ")
+ "on declarative mixin classes. "
+ )
elif name not in dict_ and not (
- '__table__' in dict_ and
- (obj.name or name) in dict_['__table__'].c
+ "__table__" in dict_
+ and (obj.name or name) in dict_["__table__"].c
):
column_copies[obj] = copy_ = obj.copy()
copy_._creation_order = obj._creation_order
@@ -357,11 +379,12 @@ class _MapperConfig(object):
our_stuff = self.properties
late_mapped = _get_immediate_cls_attr(
- cls, '_sa_decl_prepare_nocascade', strict=True)
+ cls, "_sa_decl_prepare_nocascade", strict=True
+ )
for k in list(dict_):
- if k in ('__table__', '__tablename__', '__mapper_args__'):
+ if k in ("__table__", "__tablename__", "__mapper_args__"):
continue
value = dict_[k]
@@ -371,29 +394,37 @@ class _MapperConfig(object):
"Use of @declared_attr.cascading only applies to "
"Declarative 'mixin' and 'abstract' classes. "
"Currently, this flag is ignored on mapped class "
- "%s" % self.cls)
+ "%s" % self.cls
+ )
value = getattr(cls, k)
- elif isinstance(value, QueryableAttribute) and \
- value.class_ is not cls and \
- value.key != k:
+ elif (
+ isinstance(value, QueryableAttribute)
+ and value.class_ is not cls
+ and value.key != k
+ ):
# detect a QueryableAttribute that's already mapped being
# assigned elsewhere in userland, turn into a synonym()
value = synonym(value.key)
setattr(cls, k, value)
- if (isinstance(value, tuple) and len(value) == 1 and
- isinstance(value[0], (Column, MapperProperty))):
- util.warn("Ignoring declarative-like tuple value of attribute "
- "'%s': possibly a copy-and-paste error with a comma "
- "accidentally placed at the end of the line?" % k)
+ if (
+ isinstance(value, tuple)
+ and len(value) == 1
+ and isinstance(value[0], (Column, MapperProperty))
+ ):
+ util.warn(
+ "Ignoring declarative-like tuple value of attribute "
+ "'%s': possibly a copy-and-paste error with a comma "
+ "accidentally placed at the end of the line?" % k
+ )
continue
elif not isinstance(value, (Column, MapperProperty)):
# using @declared_attr for some object that
# isn't Column/MapperProperty; remove from the dict_
# and place the evaluated value onto the class.
- if not k.startswith('__'):
+ if not k.startswith("__"):
dict_.pop(k)
self._warn_for_decl_attributes(cls, k, value)
if not late_mapped:
@@ -402,7 +433,7 @@ class _MapperConfig(object):
# we expect to see the name 'metadata' in some valid cases;
# however at this point we see it's assigned to something trying
# to be mapped, so raise for that.
- elif k == 'metadata':
+ elif k == "metadata":
raise exc.InvalidRequestError(
"Attribute name 'metadata' is reserved "
"for the MetaData instance when using a "
@@ -423,8 +454,7 @@ class _MapperConfig(object):
for key, c in list(our_stuff.items()):
if isinstance(c, (ColumnProperty, CompositeProperty)):
for col in c.columns:
- if isinstance(col, Column) and \
- col.table is None:
+ if isinstance(col, Column) and col.table is None:
_undefer_column_name(key, col)
if not isinstance(c, CompositeProperty):
name_to_prop_key[col.name].add(key)
@@ -447,8 +477,8 @@ class _MapperConfig(object):
"On class %r, Column object %r named "
"directly multiple times, "
"only one will be used: %s. "
- "Consider using orm.synonym instead" %
- (self.classname, name, (", ".join(sorted(keys))))
+ "Consider using orm.synonym instead"
+ % (self.classname, name, (", ".join(sorted(keys))))
)
def _setup_table(self):
@@ -459,15 +489,16 @@ class _MapperConfig(object):
declared_columns = self.declared_columns
declared_columns = self.declared_columns = sorted(
- declared_columns, key=lambda c: c._creation_order)
+ declared_columns, key=lambda c: c._creation_order
+ )
table = None
- if hasattr(cls, '__table_cls__'):
+ if hasattr(cls, "__table_cls__"):
table_cls = util.unbound_method_to_callable(cls.__table_cls__)
else:
table_cls = Table
- if '__table__' not in dict_:
+ if "__table__" not in dict_:
if tablename is not None:
args, table_kw = (), {}
@@ -480,14 +511,16 @@ class _MapperConfig(object):
else:
args = table_args
- autoload = dict_.get('__autoload__')
+ autoload = dict_.get("__autoload__")
if autoload:
- table_kw['autoload'] = True
+ table_kw["autoload"] = True
cls.__table__ = table = table_cls(
- tablename, cls.metadata,
+ tablename,
+ cls.metadata,
*(tuple(declared_columns) + tuple(args)),
- **table_kw)
+ **table_kw
+ )
else:
table = cls.__table__
if declared_columns:
@@ -512,21 +545,27 @@ class _MapperConfig(object):
c = _resolve_for_abstract_or_classical(c)
if c is None:
continue
- if _declared_mapping_info(c) is not None and \
- not _get_immediate_cls_attr(
- c, '_sa_decl_prepare_nocascade', strict=True):
+ if _declared_mapping_info(
+ c
+ ) is not None and not _get_immediate_cls_attr(
+ c, "_sa_decl_prepare_nocascade", strict=True
+ ):
inherits.append(c)
if inherits:
if len(inherits) > 1:
raise exc.InvalidRequestError(
- "Class %s has multiple mapped bases: %r" % (cls, inherits))
+ "Class %s has multiple mapped bases: %r" % (cls, inherits)
+ )
self.inherits = inherits[0]
else:
self.inherits = None
- if table is None and self.inherits is None and \
- not _get_immediate_cls_attr(cls, '__no_table__'):
+ if (
+ table is None
+ and self.inherits is None
+ and not _get_immediate_cls_attr(cls, "__no_table__")
+ ):
raise exc.InvalidRequestError(
"Class %r does not have a __table__ or __tablename__ "
@@ -553,8 +592,8 @@ class _MapperConfig(object):
continue
raise exc.ArgumentError(
"Column '%s' on class %s conflicts with "
- "existing column '%s'" %
- (c, cls, inherited_table.c[c.name])
+ "existing column '%s'"
+ % (c, cls, inherited_table.c[c.name])
)
if c.primary_key:
raise exc.ArgumentError(
@@ -562,8 +601,10 @@ class _MapperConfig(object):
"class with no table."
)
inherited_table.append_column(c)
- if inherited_mapped_table is not None and \
- inherited_mapped_table is not inherited_table:
+ if (
+ inherited_mapped_table is not None
+ and inherited_mapped_table is not inherited_table
+ ):
inherited_mapped_table._refresh_for_new_column(c)
def _prepare_mapper_arguments(self):
@@ -575,18 +616,19 @@ class _MapperConfig(object):
# make sure that column copies are used rather
# than the original columns from any mixins
- for k in ('version_id_col', 'polymorphic_on',):
+ for k in ("version_id_col", "polymorphic_on"):
if k in mapper_args:
v = mapper_args[k]
mapper_args[k] = self.column_copies.get(v, v)
- assert 'inherits' not in mapper_args, \
- "Can't specify 'inherits' explicitly with declarative mappings"
+ assert (
+ "inherits" not in mapper_args
+ ), "Can't specify 'inherits' explicitly with declarative mappings"
if self.inherits:
- mapper_args['inherits'] = self.inherits
+ mapper_args["inherits"] = self.inherits
- if self.inherits and not mapper_args.get('concrete', False):
+ if self.inherits and not mapper_args.get("concrete", False):
# single or joined inheritance
# exclude any cols on the inherited table which are
# not mapped on the parent class, to avoid
@@ -594,16 +636,17 @@ class _MapperConfig(object):
inherited_mapper = _declared_mapping_info(self.inherits)
inherited_table = inherited_mapper.local_table
- if 'exclude_properties' not in mapper_args:
- mapper_args['exclude_properties'] = exclude_properties = \
- set(
- [c.key for c in inherited_table.c
- if c not in inherited_mapper._columntoproperty]
- ).union(
- inherited_mapper.exclude_properties or ()
- )
+ if "exclude_properties" not in mapper_args:
+ mapper_args["exclude_properties"] = exclude_properties = set(
+ [
+ c.key
+ for c in inherited_table.c
+ if c not in inherited_mapper._columntoproperty
+ ]
+ ).union(inherited_mapper.exclude_properties or ())
exclude_properties.difference_update(
- [c.key for c in self.declared_columns])
+ [c.key for c in self.declared_columns]
+ )
# look through columns in the current mapper that
# are keyed to a propname different than the colname
@@ -621,21 +664,20 @@ class _MapperConfig(object):
# first. See [ticket:1892] for background.
properties[k] = [col] + p.columns
result_mapper_args = mapper_args.copy()
- result_mapper_args['properties'] = properties
+ result_mapper_args["properties"] = properties
self.mapper_args = result_mapper_args
def map(self):
self._prepare_mapper_arguments()
- if hasattr(self.cls, '__mapper_cls__'):
+ if hasattr(self.cls, "__mapper_cls__"):
mapper_cls = util.unbound_method_to_callable(
- self.cls.__mapper_cls__)
+ self.cls.__mapper_cls__
+ )
else:
mapper_cls = mapper
self.cls.__mapper__ = mp_ = mapper_cls(
- self.cls,
- self.local_table,
- **self.mapper_args
+ self.cls, self.local_table, **self.mapper_args
)
del self.cls._sa_declared_attr_reg
return mp_
@@ -663,8 +705,7 @@ class _DeferredMapperConfig(_MapperConfig):
@classmethod
def has_cls(cls, class_):
# 2.6 fails on weakref if class_ is an old style class
- return isinstance(class_, type) and \
- weakref.ref(class_) in cls._configs
+ return isinstance(class_, type) and weakref.ref(class_) in cls._configs
@classmethod
def config_for_cls(cls, class_):
@@ -673,18 +714,15 @@ class _DeferredMapperConfig(_MapperConfig):
@classmethod
def classes_for_base(cls, base_cls, sort=True):
classes_for_base = [
- m for m, cls_ in
- [(m, m.cls) for m in cls._configs.values()]
+ m
+ for m, cls_ in [(m, m.cls) for m in cls._configs.values()]
if cls_ is not None and issubclass(cls_, base_cls)
]
if not sort:
return classes_for_base
- all_m_by_cls = dict(
- (m.cls, m)
- for m in classes_for_base
- )
+ all_m_by_cls = dict((m.cls, m) for m in classes_for_base)
tuples = []
for m_cls in all_m_by_cls:
@@ -693,12 +731,7 @@ class _DeferredMapperConfig(_MapperConfig):
for base_cls in m_cls.__bases__
if base_cls in all_m_by_cls
)
- return list(
- topological.sort(
- tuples,
- classes_for_base
- )
- )
+ return list(topological.sort(tuples, classes_for_base))
def map(self):
self._configs.pop(self._cls, None)
@@ -713,7 +746,7 @@ def _add_attribute(cls, key, value):
"""
- if '__mapper__' in cls.__dict__:
+ if "__mapper__" in cls.__dict__:
if isinstance(value, Column):
_undefer_column_name(key, value)
cls.__table__.append_column(value)
@@ -726,16 +759,14 @@ def _add_attribute(cls, key, value):
cls.__mapper__.add_property(key, value)
elif isinstance(value, MapperProperty):
cls.__mapper__.add_property(
- key,
- clsregistry._deferred_relationship(cls, value)
+ key, clsregistry._deferred_relationship(cls, value)
)
elif isinstance(value, QueryableAttribute) and value.key != key:
# detect a QueryableAttribute that's already mapped being
# assigned elsewhere in userland, turn into a synonym()
value = synonym(value.key)
cls.__mapper__.add_property(
- key,
- clsregistry._deferred_relationship(cls, value)
+ key, clsregistry._deferred_relationship(cls, value)
)
else:
type.__setattr__(cls, key, value)
@@ -746,15 +777,18 @@ def _add_attribute(cls, key, value):
def _del_attribute(cls, key):
- if '__mapper__' in cls.__dict__ and \
- key in cls.__dict__ and not cls.__mapper__._dispose_called:
+ if (
+ "__mapper__" in cls.__dict__
+ and key in cls.__dict__
+ and not cls.__mapper__._dispose_called
+ ):
value = cls.__dict__[key]
if isinstance(
- value,
- (Column, ColumnProperty, MapperProperty, QueryableAttribute)
+ value, (Column, ColumnProperty, MapperProperty, QueryableAttribute)
):
raise NotImplementedError(
- "Can't un-map individual mapped attributes on a mapped class.")
+ "Can't un-map individual mapped attributes on a mapped class."
+ )
else:
type.__delattr__(cls, key)
cls.__mapper__._expire_memoizations()
@@ -776,10 +810,12 @@ def _declarative_constructor(self, **kwargs):
for k in kwargs:
if not hasattr(cls_, k):
raise TypeError(
- "%r is an invalid keyword argument for %s" %
- (k, cls_.__name__))
+ "%r is an invalid keyword argument for %s" % (k, cls_.__name__)
+ )
setattr(self, k, kwargs[k])
-_declarative_constructor.__name__ = '__init__'
+
+
+_declarative_constructor.__name__ = "__init__"
def _undefer_column_name(key, column):
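A minimal, illustrative sketch of what the reformatted _declarative_constructor above does at runtime; the User class is a placeholder, not part of this diff:

    from sqlalchemy import Column, Integer, String
    from sqlalchemy.ext.declarative import declarative_base

    Base = declarative_base()

    class User(Base):
        __tablename__ = "user"
        id = Column(Integer, primary_key=True)
        name = Column(String)

    User(name="ed")  # each kwarg becomes setattr(self, k, kwargs[k])
    try:
        User(nom="ed")  # no such attribute on the class
    except TypeError as err:
        print(err)  # 'nom' is an invalid keyword argument for User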
diff --git a/lib/sqlalchemy/ext/declarative/clsregistry.py b/lib/sqlalchemy/ext/declarative/clsregistry.py
index e941b9ed3..c52ae4a2f 100644
--- a/lib/sqlalchemy/ext/declarative/clsregistry.py
+++ b/lib/sqlalchemy/ext/declarative/clsregistry.py
@@ -10,8 +10,11 @@ This system allows specification of classes and expressions used in
:func:`.relationship` using strings.
"""
-from ...orm.properties import ColumnProperty, RelationshipProperty, \
- SynonymProperty
+from ...orm.properties import (
+ ColumnProperty,
+ RelationshipProperty,
+ SynonymProperty,
+)
from ...schema import _get_table_key
from ...orm import class_mapper, interfaces
from ... import util
@@ -35,17 +38,18 @@ def add_class(classname, cls):
# class already exists.
existing = cls._decl_class_registry[classname]
if not isinstance(existing, _MultipleClassMarker):
- existing = \
- cls._decl_class_registry[classname] = \
- _MultipleClassMarker([cls, existing])
+ existing = cls._decl_class_registry[
+ classname
+ ] = _MultipleClassMarker([cls, existing])
else:
cls._decl_class_registry[classname] = cls
try:
- root_module = cls._decl_class_registry['_sa_module_registry']
+ root_module = cls._decl_class_registry["_sa_module_registry"]
except KeyError:
- cls._decl_class_registry['_sa_module_registry'] = \
- root_module = _ModuleMarker('_sa_module_registry', None)
+ cls._decl_class_registry[
+ "_sa_module_registry"
+ ] = root_module = _ModuleMarker("_sa_module_registry", None)
tokens = cls.__module__.split(".")
@@ -71,12 +75,13 @@ class _MultipleClassMarker(object):
"""
- __slots__ = 'on_remove', 'contents', '__weakref__'
+ __slots__ = "on_remove", "contents", "__weakref__"
def __init__(self, classes, on_remove=None):
self.on_remove = on_remove
- self.contents = set([
- weakref.ref(item, self._remove_item) for item in classes])
+ self.contents = set(
+ [weakref.ref(item, self._remove_item) for item in classes]
+ )
_registries.add(self)
def __iter__(self):
@@ -85,10 +90,10 @@ class _MultipleClassMarker(object):
def attempt_get(self, path, key):
if len(self.contents) > 1:
raise exc.InvalidRequestError(
- "Multiple classes found for path \"%s\" "
+ 'Multiple classes found for path "%s" '
"in the registry of this declarative "
- "base. Please use a fully module-qualified path." %
- (".".join(path + [key]))
+ "base. Please use a fully module-qualified path."
+ % (".".join(path + [key]))
)
else:
ref = list(self.contents)[0]
@@ -108,17 +113,19 @@ class _MultipleClassMarker(object):
# protect against class registration race condition against
# asynchronous garbage collection calling _remove_item,
# [ticket:3208]
- modules = set([
- cls.__module__ for cls in
- [ref() for ref in self.contents] if cls is not None])
+ modules = set(
+ [
+ cls.__module__
+ for cls in [ref() for ref in self.contents]
+ if cls is not None
+ ]
+ )
if item.__module__ in modules:
util.warn(
"This declarative base already contains a class with the "
"same class name and module name as %s.%s, and will "
- "be replaced in the string-lookup table." % (
- item.__module__,
- item.__name__
- )
+ "be replaced in the string-lookup table."
+ % (item.__module__, item.__name__)
)
self.contents.add(weakref.ref(item, self._remove_item))
@@ -129,7 +136,7 @@ class _ModuleMarker(object):
"""
- __slots__ = 'parent', 'name', 'contents', 'mod_ns', 'path', '__weakref__'
+ __slots__ = "parent", "name", "contents", "mod_ns", "path", "__weakref__"
def __init__(self, name, parent):
self.parent = parent
@@ -170,13 +177,13 @@ class _ModuleMarker(object):
existing = self.contents[name]
existing.add_item(cls)
else:
- existing = self.contents[name] = \
- _MultipleClassMarker([cls],
- on_remove=lambda: self._remove_item(name))
+ existing = self.contents[name] = _MultipleClassMarker(
+ [cls], on_remove=lambda: self._remove_item(name)
+ )
class _ModNS(object):
- __slots__ = '__parent',
+ __slots__ = ("__parent",)
def __init__(self, parent):
self.__parent = parent
@@ -193,13 +200,14 @@ class _ModNS(object):
else:
assert isinstance(value, _MultipleClassMarker)
return value.attempt_get(self.__parent.path, key)
- raise AttributeError("Module %r has no mapped classes "
- "registered under the name %r" % (
- self.__parent.name, key))
+ raise AttributeError(
+ "Module %r has no mapped classes "
+ "registered under the name %r" % (self.__parent.name, key)
+ )
class _GetColumns(object):
- __slots__ = 'cls',
+ __slots__ = ("cls",)
def __init__(self, cls):
self.cls = cls
@@ -210,7 +218,8 @@ class _GetColumns(object):
if key not in mp.all_orm_descriptors:
raise exc.InvalidRequestError(
"Class %r does not have a mapped column named %r"
- % (self.cls, key))
+ % (self.cls, key)
+ )
desc = mp.all_orm_descriptors[key]
if desc.extension_type is interfaces.NOT_EXTENSION:
@@ -221,24 +230,25 @@ class _GetColumns(object):
raise exc.InvalidRequestError(
"Property %r is not an instance of"
" ColumnProperty (i.e. does not correspond"
- " directly to a Column)." % key)
+ " directly to a Column)." % key
+ )
return getattr(self.cls, key)
+
inspection._inspects(_GetColumns)(
- lambda target: inspection.inspect(target.cls))
+ lambda target: inspection.inspect(target.cls)
+)
class _GetTable(object):
- __slots__ = 'key', 'metadata'
+ __slots__ = "key", "metadata"
def __init__(self, key, metadata):
self.key = key
self.metadata = metadata
def __getattr__(self, key):
- return self.metadata.tables[
- _get_table_key(key, self.key)
- ]
+ return self.metadata.tables[_get_table_key(key, self.key)]
def _determine_container(key, value):
@@ -264,9 +274,11 @@ class _class_resolver(object):
return cls.metadata.tables[key]
elif key in cls.metadata._schemas:
return _GetTable(key, cls.metadata)
- elif '_sa_module_registry' in cls._decl_class_registry and \
- key in cls._decl_class_registry['_sa_module_registry']:
- registry = cls._decl_class_registry['_sa_module_registry']
+ elif (
+ "_sa_module_registry" in cls._decl_class_registry
+ and key in cls._decl_class_registry["_sa_module_registry"]
+ ):
+ registry = cls._decl_class_registry["_sa_module_registry"]
return registry.resolve_attr(key)
elif self._resolvers:
for resolv in self._resolvers:
@@ -289,8 +301,8 @@ class _class_resolver(object):
"When initializing mapper %s, expression %r failed to "
"locate a name (%r). If this is a class name, consider "
"adding this relationship() to the %r class after "
- "both dependent classes have been defined." %
- (self.prop.parent, self.arg, n.args[0], self.cls)
+ "both dependent classes have been defined."
+ % (self.prop.parent, self.arg, n.args[0], self.cls)
)
@@ -299,10 +311,11 @@ def _resolver(cls, prop):
from sqlalchemy.orm import foreign, remote
fallback = sqlalchemy.__dict__.copy()
- fallback.update({'foreign': foreign, 'remote': remote})
+ fallback.update({"foreign": foreign, "remote": remote})
def resolve_arg(arg):
return _class_resolver(cls, prop, fallback, arg)
+
return resolve_arg
@@ -311,18 +324,32 @@ def _deferred_relationship(cls, prop):
if isinstance(prop, RelationshipProperty):
resolve_arg = _resolver(cls, prop)
- for attr in ('argument', 'order_by', 'primaryjoin', 'secondaryjoin',
- 'secondary', '_user_defined_foreign_keys', 'remote_side'):
+ for attr in (
+ "argument",
+ "order_by",
+ "primaryjoin",
+ "secondaryjoin",
+ "secondary",
+ "_user_defined_foreign_keys",
+ "remote_side",
+ ):
v = getattr(prop, attr)
if isinstance(v, util.string_types):
setattr(prop, attr, resolve_arg(v))
if prop.backref and isinstance(prop.backref, tuple):
key, kwargs = prop.backref
- for attr in ('primaryjoin', 'secondaryjoin', 'secondary',
- 'foreign_keys', 'remote_side', 'order_by'):
- if attr in kwargs and isinstance(kwargs[attr],
- util.string_types):
+ for attr in (
+ "primaryjoin",
+ "secondaryjoin",
+ "secondary",
+ "foreign_keys",
+ "remote_side",
+ "order_by",
+ ):
+ if attr in kwargs and isinstance(
+ kwargs[attr], util.string_types
+ ):
kwargs[attr] = resolve_arg(kwargs[attr])
return prop
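A minimal sketch of the string-resolution behavior the clsregistry changes above service; Parent/Child are illustrative names:

    from sqlalchemy import Column, ForeignKey, Integer
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import configure_mappers, relationship

    Base = declarative_base()

    class Parent(Base):
        __tablename__ = "parent"
        id = Column(Integer, primary_key=True)
        # string arguments are wrapped by _deferred_relationship() and
        # looked up in the declarative class registry at configure time
        children = relationship(
            "Child", primaryjoin="Parent.id == Child.parent_id"
        )

    class Child(Base):
        __tablename__ = "child"
        id = Column(Integer, primary_key=True)
        parent_id = Column(Integer, ForeignKey("parent.id"))

    configure_mappers()  # names resolve here; a missing name raises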
diff --git a/lib/sqlalchemy/ext/horizontal_shard.py b/lib/sqlalchemy/ext/horizontal_shard.py
index f86e4fc93..7248e5b4d 100644
--- a/lib/sqlalchemy/ext/horizontal_shard.py
+++ b/lib/sqlalchemy/ext/horizontal_shard.py
@@ -20,7 +20,7 @@ from .. import util
from ..orm.session import Session
from ..orm.query import Query
-__all__ = ['ShardedSession', 'ShardedQuery']
+__all__ = ["ShardedSession", "ShardedQuery"]
class ShardedQuery(Query):
@@ -43,12 +43,10 @@ class ShardedQuery(Query):
def _execute_and_instances(self, context):
def iter_for_shard(shard_id):
- context.attributes['shard_id'] = context.identity_token = shard_id
+ context.attributes["shard_id"] = context.identity_token = shard_id
result = self._connection_from_session(
- mapper=self._bind_mapper(),
- shard_id=shard_id).execute(
- context.statement,
- self._params)
+ mapper=self._bind_mapper(), shard_id=shard_id
+ ).execute(context.statement, self._params)
return self.instances(result, context)
if context.identity_token is not None:
@@ -70,7 +68,8 @@ class ShardedQuery(Query):
mapper=mapper,
shard_id=shard_id,
clause=stmt,
- close_with_result=True)
+ close_with_result=True,
+ )
result = conn.execute(stmt, self._params)
return result
@@ -87,8 +86,13 @@ class ShardedQuery(Query):
return ShardedResult(results, rowcount)
def _identity_lookup(
- self, mapper, primary_key_identity, identity_token=None,
- lazy_loaded_from=None, **kw):
+ self,
+ mapper,
+ primary_key_identity,
+ identity_token=None,
+ lazy_loaded_from=None,
+ **kw
+ ):
"""override the default Query._identity_lookup method so that we
search for a given non-token primary key identity across all
possible identity tokens (e.g. shard ids).
@@ -97,8 +101,10 @@ class ShardedQuery(Query):
if identity_token is not None:
return super(ShardedQuery, self)._identity_lookup(
- mapper, primary_key_identity,
- identity_token=identity_token, **kw
+ mapper,
+ primary_key_identity,
+ identity_token=identity_token,
+ **kw
)
else:
q = self.session.query(mapper)
@@ -113,13 +119,13 @@ class ShardedQuery(Query):
return None
- def _get_impl(
- self, primary_key_identity, db_load_fn, identity_token=None):
+ def _get_impl(self, primary_key_identity, db_load_fn, identity_token=None):
"""Override the default Query._get_impl() method so that we emit
a query to the DB for each possible identity token, if we don't
have one already.
"""
+
def _db_load_fn(query, primary_key_identity):
# load from the database. The original db_load_fn will
# use the given Query object to load from the DB, so our
@@ -142,7 +148,8 @@ class ShardedQuery(Query):
identity_token = self._shard_id
return super(ShardedQuery, self)._get_impl(
- primary_key_identity, _db_load_fn, identity_token=identity_token)
+ primary_key_identity, _db_load_fn, identity_token=identity_token
+ )
class ShardedResult(object):
@@ -158,7 +165,7 @@ class ShardedResult(object):
.. versionadded:: 1.3
"""
- __slots__ = ('result_proxies', 'aggregate_rowcount',)
+ __slots__ = ("result_proxies", "aggregate_rowcount")
def __init__(self, result_proxies, aggregate_rowcount):
self.result_proxies = result_proxies
@@ -168,9 +175,17 @@ class ShardedResult(object):
def rowcount(self):
return self.aggregate_rowcount
+
class ShardedSession(Session):
- def __init__(self, shard_chooser, id_chooser, query_chooser, shards=None,
- query_cls=ShardedQuery, **kwargs):
+ def __init__(
+ self,
+ shard_chooser,
+ id_chooser,
+ query_chooser,
+ shards=None,
+ query_cls=ShardedQuery,
+ **kwargs
+ ):
"""Construct a ShardedSession.
:param shard_chooser: A callable which, passed a Mapper, a mapped
@@ -225,16 +240,16 @@ class ShardedSession(Session):
return self.transaction.connection(mapper, shard_id=shard_id)
else:
return self.get_bind(
- mapper,
- shard_id=shard_id,
- instance=instance
+ mapper, shard_id=shard_id, instance=instance
).contextual_connect(**kwargs)
- def get_bind(self, mapper, shard_id=None,
- instance=None, clause=None, **kw):
+ def get_bind(
+ self, mapper, shard_id=None, instance=None, clause=None, **kw
+ ):
if shard_id is None:
shard_id = self._choose_shard_and_assign(
- mapper, instance, clause=clause)
+ mapper, instance, clause=clause
+ )
return self.__binds[shard_id]
def bind_shard(self, shard_id, bind):
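A construction sketch for the ShardedSession signature reformatted above; the chooser rules and the two in-memory engines are placeholders:

    from sqlalchemy import create_engine
    from sqlalchemy.ext.horizontal_shard import ShardedSession

    engines = {
        "odd": create_engine("sqlite://"),
        "even": create_engine("sqlite://"),
    }

    def shard_chooser(mapper, instance, clause=None):
        return "odd"  # trivial placeholder rule for flushes

    def id_chooser(query, ident):
        return ["odd", "even"]  # shards that may hold this primary key

    def query_chooser(query):
        return ["odd", "even"]  # shards a query should run against

    session = ShardedSession(
        shard_chooser=shard_chooser,
        id_chooser=id_chooser,
        query_chooser=query_chooser,
        shards=engines,
    )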
diff --git a/lib/sqlalchemy/ext/hybrid.py b/lib/sqlalchemy/ext/hybrid.py
index 95eecb93f..d51a083da 100644
--- a/lib/sqlalchemy/ext/hybrid.py
+++ b/lib/sqlalchemy/ext/hybrid.py
@@ -778,7 +778,7 @@ there's probably a whole lot of amazing things it can be used for.
from .. import util
from ..orm import attributes, interfaces
-HYBRID_METHOD = util.symbol('HYBRID_METHOD')
+HYBRID_METHOD = util.symbol("HYBRID_METHOD")
"""Symbol indicating an :class:`InspectionAttr` that's
of type :class:`.hybrid_method`.
@@ -791,7 +791,7 @@ HYBRID_METHOD = util.symbol('HYBRID_METHOD')
"""
-HYBRID_PROPERTY = util.symbol('HYBRID_PROPERTY')
+HYBRID_PROPERTY = util.symbol("HYBRID_PROPERTY")
"""Symbol indicating an :class:`InspectionAttr` that's
of type :class:`.hybrid_property`.
@@ -860,8 +860,14 @@ class hybrid_property(interfaces.InspectionAttrInfo):
extension_type = HYBRID_PROPERTY
def __init__(
- self, fget, fset=None, fdel=None,
- expr=None, custom_comparator=None, update_expr=None):
+ self,
+ fget,
+ fset=None,
+ fdel=None,
+ expr=None,
+ custom_comparator=None,
+ update_expr=None,
+ ):
"""Create a new :class:`.hybrid_property`.
Usage is typically via decorator::
@@ -906,7 +912,8 @@ class hybrid_property(interfaces.InspectionAttrInfo):
defaults = {
key: value
for key, value in self.__dict__.items()
- if not key.startswith("_")}
+ if not key.startswith("_")
+ }
defaults.update(**kw)
return type(self)(**defaults)
@@ -1078,9 +1085,9 @@ class hybrid_property(interfaces.InspectionAttrInfo):
return self._get_expr(self.fget)
def _get_expr(self, expr):
-
def _expr(cls):
return ExprComparator(cls, expr(cls), self)
+
util.update_wrapper(_expr, expr)
return self._get_comparator(_expr)
@@ -1091,8 +1098,13 @@ class hybrid_property(interfaces.InspectionAttrInfo):
def expr_comparator(owner):
return proxy_attr(
- owner, self.__name__, self, comparator(owner),
- doc=comparator.__doc__ or self.__doc__)
+ owner,
+ self.__name__,
+ self,
+ comparator(owner),
+ doc=comparator.__doc__ or self.__doc__,
+ )
+
return expr_comparator
@@ -1108,7 +1120,7 @@ class Comparator(interfaces.PropComparator):
def __clause_element__(self):
expr = self.expression
- if hasattr(expr, '__clause_element__'):
+ if hasattr(expr, "__clause_element__"):
expr = expr.__clause_element__()
return expr
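The canonical hybrid_property pattern implemented by the code reformatted above; Interval is illustrative:

    from sqlalchemy import Column, Integer
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.ext.hybrid import hybrid_property

    Base = declarative_base()

    class Interval(Base):
        __tablename__ = "interval"
        id = Column(Integer, primary_key=True)
        start = Column(Integer)
        end = Column(Integer)

        @hybrid_property
        def length(self):
            # instance level: plain Python; class level: SQL expression
            return self.end - self.start

    print(Interval(start=5, end=10).length)  # 5
    print(Interval.length > 3)  # renders as a SQL comparison expression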
diff --git a/lib/sqlalchemy/ext/indexable.py b/lib/sqlalchemy/ext/indexable.py
index 0bc2b65bb..368e5b00a 100644
--- a/lib/sqlalchemy/ext/indexable.py
+++ b/lib/sqlalchemy/ext/indexable.py
@@ -232,7 +232,7 @@ from ..orm.attributes import flag_modified
from ..ext.hybrid import hybrid_property
-__all__ = ['index_property']
+__all__ = ["index_property"]
class index_property(hybrid_property): # noqa
@@ -251,8 +251,14 @@ class index_property(hybrid_property): # noqa
_NO_DEFAULT_ARGUMENT = object()
def __init__(
- self, attr_name, index, default=_NO_DEFAULT_ARGUMENT,
- datatype=None, mutable=True, onebased=True):
+ self,
+ attr_name,
+ index,
+ default=_NO_DEFAULT_ARGUMENT,
+ datatype=None,
+ mutable=True,
+ onebased=True,
+ ):
"""Create a new :class:`.index_property`.
:param attr_name:
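A short sketch of the (attr_name, index) arguments above; Person and its JSON column are illustrative:

    from sqlalchemy import JSON, Column, Integer
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.ext.indexable import index_property

    Base = declarative_base()

    class Person(Base):
        __tablename__ = "person"
        id = Column(Integer, primary_key=True)
        data = Column(JSON)
        # proxies data["name"]; a string index implies a dict datatype
        name = index_property("data", "name")

    p = Person()
    p.name = "Alchemist"
    print(p.data)  # {'name': 'Alchemist'}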
diff --git a/lib/sqlalchemy/ext/instrumentation.py b/lib/sqlalchemy/ext/instrumentation.py
index 30a0ab7d7..b2b8dd7c5 100644
--- a/lib/sqlalchemy/ext/instrumentation.py
+++ b/lib/sqlalchemy/ext/instrumentation.py
@@ -28,15 +28,18 @@ see the example :ref:`examples_instrumentation`.
"""
from ..orm import instrumentation as orm_instrumentation
from ..orm.instrumentation import (
- ClassManager, InstrumentationFactory, _default_state_getter,
- _default_dict_getter, _default_manager_getter
+ ClassManager,
+ InstrumentationFactory,
+ _default_state_getter,
+ _default_dict_getter,
+ _default_manager_getter,
)
from ..orm import attributes, collections, base as orm_base
from .. import util
from ..orm import exc as orm_exc
import weakref
-INSTRUMENTATION_MANAGER = '__sa_instrumentation_manager__'
+INSTRUMENTATION_MANAGER = "__sa_instrumentation_manager__"
"""Attribute, elects custom instrumentation when present on a mapped class.
Allows a class to specify a slightly or wildly different technique for
@@ -66,6 +69,7 @@ def find_native_user_instrumentation_hook(cls):
"""Find user-specified instrumentation management for a class."""
return getattr(cls, INSTRUMENTATION_MANAGER, None)
+
instrumentation_finders = [find_native_user_instrumentation_hook]
"""An extensible sequence of callables which return instrumentation
implementations
@@ -89,6 +93,7 @@ class ExtendedInstrumentationRegistry(InstrumentationFactory):
class managers.
"""
+
_manager_finders = weakref.WeakKeyDictionary()
_state_finders = weakref.WeakKeyDictionary()
_dict_finders = weakref.WeakKeyDictionary()
@@ -104,13 +109,15 @@ class ExtendedInstrumentationRegistry(InstrumentationFactory):
return None, None
def _check_conflicts(self, class_, factory):
- existing_factories = self._collect_management_factories_for(class_).\
- difference([factory])
+ existing_factories = self._collect_management_factories_for(
+ class_
+ ).difference([factory])
if existing_factories:
raise TypeError(
"multiple instrumentation implementations specified "
- "in %s inheritance hierarchy: %r" % (
- class_.__name__, list(existing_factories)))
+ "in %s inheritance hierarchy: %r"
+ % (class_.__name__, list(existing_factories))
+ )
def _extended_class_manager(self, class_, factory):
manager = factory(class_)
@@ -178,17 +185,20 @@ class ExtendedInstrumentationRegistry(InstrumentationFactory):
if instance is None:
raise AttributeError("None has no persistent state.")
return self._state_finders.get(
- instance.__class__, _default_state_getter)(instance)
+ instance.__class__, _default_state_getter
+ )(instance)
def dict_of(self, instance):
if instance is None:
raise AttributeError("None has no persistent state.")
return self._dict_finders.get(
- instance.__class__, _default_dict_getter)(instance)
+ instance.__class__, _default_dict_getter
+ )(instance)
-orm_instrumentation._instrumentation_factory = \
- _instrumentation_factory = ExtendedInstrumentationRegistry()
+orm_instrumentation._instrumentation_factory = (
+ _instrumentation_factory
+) = ExtendedInstrumentationRegistry()
orm_instrumentation.instrumentation_finders = instrumentation_finders
@@ -222,14 +232,15 @@ class InstrumentationManager(object):
pass
def manage(self, class_, manager):
- setattr(class_, '_default_class_manager', manager)
+ setattr(class_, "_default_class_manager", manager)
def dispose(self, class_, manager):
- delattr(class_, '_default_class_manager')
+ delattr(class_, "_default_class_manager")
def manager_getter(self, class_):
def get(cls):
return cls._default_class_manager
+
return get
def instrument_attribute(self, class_, key, inst):
@@ -260,13 +271,13 @@ class InstrumentationManager(object):
pass
def install_state(self, class_, instance, state):
- setattr(instance, '_default_state', state)
+ setattr(instance, "_default_state", state)
def remove_state(self, class_, instance):
- delattr(instance, '_default_state')
+ delattr(instance, "_default_state")
def state_getter(self, class_):
- return lambda instance: getattr(instance, '_default_state')
+ return lambda instance: getattr(instance, "_default_state")
def dict_getter(self, class_):
return lambda inst: self.get_instance_dict(class_, inst)
@@ -314,15 +325,17 @@ class _ClassInstrumentationAdapter(ClassManager):
def instrument_collection_class(self, key, collection_class):
return self._adapted.instrument_collection_class(
- self.class_, key, collection_class)
+ self.class_, key, collection_class
+ )
def initialize_collection(self, key, state, factory):
- delegate = getattr(self._adapted, 'initialize_collection', None)
+ delegate = getattr(self._adapted, "initialize_collection", None)
if delegate:
return delegate(key, state, factory)
else:
- return ClassManager.initialize_collection(self, key,
- state, factory)
+ return ClassManager.initialize_collection(
+ self, key, state, factory
+ )
def new_instance(self, state=None):
instance = self.class_.__new__(self.class_)
@@ -384,7 +397,7 @@ def _install_instrumented_lookups():
dict(
instance_state=_instrumentation_factory.state_of,
instance_dict=_instrumentation_factory.dict_of,
- manager_of_class=_instrumentation_factory.manager_of_class
+ manager_of_class=_instrumentation_factory.manager_of_class,
)
)
@@ -395,7 +408,7 @@ def _reinstall_default_lookups():
dict(
instance_state=_default_state_getter,
instance_dict=_default_dict_getter,
- manager_of_class=_default_manager_getter
+ manager_of_class=_default_manager_getter,
)
)
_instrumentation_factory._extended = False
@@ -403,12 +416,15 @@ def _reinstall_default_lookups():
def _install_lookups(lookups):
global instance_state, instance_dict, manager_of_class
- instance_state = lookups['instance_state']
- instance_dict = lookups['instance_dict']
- manager_of_class = lookups['manager_of_class']
- orm_base.instance_state = attributes.instance_state = \
- orm_instrumentation.instance_state = instance_state
- orm_base.instance_dict = attributes.instance_dict = \
- orm_instrumentation.instance_dict = instance_dict
- orm_base.manager_of_class = attributes.manager_of_class = \
- orm_instrumentation.manager_of_class = manager_of_class
+ instance_state = lookups["instance_state"]
+ instance_dict = lookups["instance_dict"]
+ manager_of_class = lookups["manager_of_class"]
+ orm_base.instance_state = (
+ attributes.instance_state
+ ) = orm_instrumentation.instance_state = instance_state
+ orm_base.instance_dict = (
+ attributes.instance_dict
+ ) = orm_instrumentation.instance_dict = instance_dict
+ orm_base.manager_of_class = (
+ attributes.manager_of_class
+ ) = orm_instrumentation.manager_of_class = manager_of_class
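A sketch of electing the INSTRUMENTATION_MANAGER hook documented above; both class names are illustrative:

    from sqlalchemy.ext.instrumentation import InstrumentationManager

    class MyStateManager(InstrumentationManager):
        # keeps the default behavior; override hooks such as
        # install_state() / state_getter() to relocate SQLAlchemy's
        # per-instance bookkeeping
        pass

    class MyDataClass(object):
        # found by find_native_user_instrumentation_hook()
        __sa_instrumentation_manager__ = MyStateManager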
diff --git a/lib/sqlalchemy/ext/mutable.py b/lib/sqlalchemy/ext/mutable.py
index 014cef3cc..0f6ccdc33 100644
--- a/lib/sqlalchemy/ext/mutable.py
+++ b/lib/sqlalchemy/ext/mutable.py
@@ -502,27 +502,29 @@ class MutableBase(object):
def pickle(state, state_dict):
val = state.dict.get(key, None)
if val is not None:
- if 'ext.mutable.values' not in state_dict:
- state_dict['ext.mutable.values'] = []
- state_dict['ext.mutable.values'].append(val)
+ if "ext.mutable.values" not in state_dict:
+ state_dict["ext.mutable.values"] = []
+ state_dict["ext.mutable.values"].append(val)
def unpickle(state, state_dict):
- if 'ext.mutable.values' in state_dict:
- for val in state_dict['ext.mutable.values']:
+ if "ext.mutable.values" in state_dict:
+ for val in state_dict["ext.mutable.values"]:
val._parents[state.obj()] = key
- event.listen(parent_cls, 'load', load,
- raw=True, propagate=True)
- event.listen(parent_cls, 'refresh', load_attrs,
- raw=True, propagate=True)
- event.listen(parent_cls, 'refresh_flush', load_attrs,
- raw=True, propagate=True)
- event.listen(attribute, 'set', set,
- raw=True, retval=True, propagate=True)
- event.listen(parent_cls, 'pickle', pickle,
- raw=True, propagate=True)
- event.listen(parent_cls, 'unpickle', unpickle,
- raw=True, propagate=True)
+ event.listen(parent_cls, "load", load, raw=True, propagate=True)
+ event.listen(
+ parent_cls, "refresh", load_attrs, raw=True, propagate=True
+ )
+ event.listen(
+ parent_cls, "refresh_flush", load_attrs, raw=True, propagate=True
+ )
+ event.listen(
+ attribute, "set", set, raw=True, retval=True, propagate=True
+ )
+ event.listen(parent_cls, "pickle", pickle, raw=True, propagate=True)
+ event.listen(
+ parent_cls, "unpickle", unpickle, raw=True, propagate=True
+ )
class Mutable(MutableBase):
@@ -572,7 +574,7 @@ class Mutable(MutableBase):
if isinstance(prop.columns[0].type, sqltype):
cls.associate_with_attribute(getattr(class_, prop.key))
- event.listen(mapper, 'mapper_configured', listen_for_type)
+ event.listen(mapper, "mapper_configured", listen_for_type)
@classmethod
def as_mutable(cls, sqltype):
@@ -613,9 +615,11 @@ class Mutable(MutableBase):
# and we'll lose our ability to link that type back to the original.
# so track our original type w/ columns
if isinstance(sqltype, SchemaEventTarget):
+
@event.listens_for(sqltype, "before_parent_attach")
def _add_column_memo(sqltyp, parent):
- parent.info['_ext_mutable_orig_type'] = sqltyp
+ parent.info["_ext_mutable_orig_type"] = sqltyp
+
schema_event_check = True
else:
schema_event_check = False
@@ -625,16 +629,14 @@ class Mutable(MutableBase):
return
for prop in mapper.column_attrs:
if (
- schema_event_check and
- hasattr(prop.expression, 'info') and
- prop.expression.info.get('_ext_mutable_orig_type')
- is sqltype
- ) or (
- prop.columns[0].type is sqltype
- ):
+ schema_event_check
+ and hasattr(prop.expression, "info")
+ and prop.expression.info.get("_ext_mutable_orig_type")
+ is sqltype
+ ) or (prop.columns[0].type is sqltype):
cls.associate_with_attribute(getattr(class_, prop.key))
- event.listen(mapper, 'mapper_configured', listen_for_type)
+ event.listen(mapper, "mapper_configured", listen_for_type)
return sqltype
@@ -659,21 +661,27 @@ class MutableComposite(MutableBase):
prop = object_mapper(parent).get_property(key)
for value, attr_name in zip(
- self.__composite_values__(),
- prop._attribute_keys):
+ self.__composite_values__(), prop._attribute_keys
+ ):
setattr(parent, attr_name, value)
def _setup_composite_listener():
def _listen_for_type(mapper, class_):
for prop in mapper.iterate_properties:
- if (hasattr(prop, 'composite_class') and
- isinstance(prop.composite_class, type) and
- issubclass(prop.composite_class, MutableComposite)):
+ if (
+ hasattr(prop, "composite_class")
+ and isinstance(prop.composite_class, type)
+ and issubclass(prop.composite_class, MutableComposite)
+ ):
prop.composite_class._listen_on_attribute(
- getattr(class_, prop.key), False, class_)
+ getattr(class_, prop.key), False, class_
+ )
+
if not event.contains(Mapper, "mapper_configured", _listen_for_type):
- event.listen(Mapper, 'mapper_configured', _listen_for_type)
+ event.listen(Mapper, "mapper_configured", _listen_for_type)
+
+
_setup_composite_listener()
@@ -947,4 +955,4 @@ class MutableSet(Mutable, set):
self.update(state)
def __reduce_ex__(self, proto):
- return (self.__class__, (list(self), ))
+ return (self.__class__, (list(self),))
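Typical use of the mutable extension reformatted above: as_mutable() associates the Mutable subclass with a column's type via the "mapper_configured" listeners shown. Account is illustrative:

    from sqlalchemy import JSON, Column, Integer
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.ext.mutable import MutableDict

    Base = declarative_base()

    class Account(Base):
        __tablename__ = "account"
        id = Column(Integer, primary_key=True)
        data = Column(MutableDict.as_mutable(JSON))

    acct = Account(data={})
    acct.data["plan"] = "basic"  # in-place change marks the attribute dirty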
diff --git a/lib/sqlalchemy/ext/orderinglist.py b/lib/sqlalchemy/ext/orderinglist.py
index 316742a67..2a8522120 100644
--- a/lib/sqlalchemy/ext/orderinglist.py
+++ b/lib/sqlalchemy/ext/orderinglist.py
@@ -122,7 +122,7 @@ start numbering at 1 or some other integer, provide ``count_from=1``.
from ..orm.collections import collection, collection_adapter
from .. import util
-__all__ = ['ordering_list']
+__all__ = ["ordering_list"]
def ordering_list(attr, count_from=None, **kw):
@@ -180,8 +180,9 @@ def count_from_n_factory(start):
def f(index, collection):
return index + start
+
try:
- f.__name__ = 'count_from_%i' % start
+ f.__name__ = "count_from_%i" % start
except TypeError:
pass
return f
@@ -194,14 +195,14 @@ def _unsugar_count_from(**kw):
``count_from`` argument, otherwise passes ``ordering_func`` on unchanged.
"""
- count_from = kw.pop('count_from', None)
- if kw.get('ordering_func', None) is None and count_from is not None:
+ count_from = kw.pop("count_from", None)
+ if kw.get("ordering_func", None) is None and count_from is not None:
if count_from == 0:
- kw['ordering_func'] = count_from_0
+ kw["ordering_func"] = count_from_0
elif count_from == 1:
- kw['ordering_func'] = count_from_1
+ kw["ordering_func"] = count_from_1
else:
- kw['ordering_func'] = count_from_n_factory(count_from)
+ kw["ordering_func"] = count_from_n_factory(count_from)
return kw
@@ -214,8 +215,9 @@ class OrderingList(list):
"""
- def __init__(self, ordering_attr=None, ordering_func=None,
- reorder_on_append=False):
+ def __init__(
+ self, ordering_attr=None, ordering_func=None, reorder_on_append=False
+ ):
"""A custom list that manages position information for its children.
``OrderingList`` is a ``collection_class`` list implementation that
@@ -311,6 +313,7 @@ class OrderingList(list):
"""Append without any ordering behavior."""
super(OrderingList, self).append(entity)
+
_raw_append = collection.adds(1)(_raw_append)
def insert(self, index, entity):
@@ -361,8 +364,12 @@ class OrderingList(list):
return _reconstitute, (self.__class__, self.__dict__, list(self))
for func_name, func in list(locals().items()):
- if (util.callable(func) and func.__name__ == func_name and
- not func.__doc__ and hasattr(list, func_name)):
+ if (
+ util.callable(func)
+ and func.__name__ == func_name
+ and not func.__doc__
+ and hasattr(list, func_name)
+ ):
func.__doc__ = getattr(list, func_name).__doc__
del func_name, func
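The standard ordering_list usage the reformatted module supports; Slide/Bullet are illustrative:

    from sqlalchemy import Column, ForeignKey, Integer, String
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.ext.orderinglist import ordering_list
    from sqlalchemy.orm import relationship

    Base = declarative_base()

    class Slide(Base):
        __tablename__ = "slide"
        id = Column(Integer, primary_key=True)
        bullets = relationship(
            "Bullet",
            order_by="Bullet.position",
            collection_class=ordering_list("position"),
        )

    class Bullet(Base):
        __tablename__ = "bullet"
        id = Column(Integer, primary_key=True)
        slide_id = Column(Integer, ForeignKey("slide.id"))
        position = Column(Integer)
        text = Column(String)

    s = Slide()
    s.bullets.append(Bullet(text="first"))
    s.bullets.append(Bullet(text="second"))
    print([b.position for b in s.bullets])  # [0, 1]; count_from defaults to 0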
diff --git a/lib/sqlalchemy/ext/serializer.py b/lib/sqlalchemy/ext/serializer.py
index 2fded51d1..3adcec34f 100644
--- a/lib/sqlalchemy/ext/serializer.py
+++ b/lib/sqlalchemy/ext/serializer.py
@@ -64,7 +64,7 @@ from ..util import pickle, byte_buffer, b64encode, b64decode, text_type
import re
-__all__ = ['Serializer', 'Deserializer', 'dumps', 'loads']
+__all__ = ["Serializer", "Deserializer", "dumps", "loads"]
def Serializer(*args, **kw):
@@ -79,13 +79,18 @@ def Serializer(*args, **kw):
elif isinstance(obj, Mapper) and not obj.non_primary:
id = "mapper:" + b64encode(pickle.dumps(obj.class_))
elif isinstance(obj, MapperProperty) and not obj.parent.non_primary:
- id = "mapperprop:" + b64encode(pickle.dumps(obj.parent.class_)) + \
- ":" + obj.key
+ id = (
+ "mapperprop:"
+ + b64encode(pickle.dumps(obj.parent.class_))
+ + ":"
+ + obj.key
+ )
elif isinstance(obj, Table):
id = "table:" + text_type(obj.key)
elif isinstance(obj, Column) and isinstance(obj.table, Table):
- id = "column:" + \
- text_type(obj.table.key) + ":" + text_type(obj.key)
+ id = (
+ "column:" + text_type(obj.table.key) + ":" + text_type(obj.key)
+ )
elif isinstance(obj, Session):
id = "session:"
elif isinstance(obj, Engine):
@@ -97,8 +102,10 @@ def Serializer(*args, **kw):
pickler.persistent_id = persistent_id
return pickler
+
our_ids = re.compile(
- r'(mapperprop|mapper|table|column|session|attribute|engine):(.*)')
+ r"(mapperprop|mapper|table|column|session|attribute|engine):(.*)"
+)
def Deserializer(file, metadata=None, scoped_session=None, engine=None):
@@ -120,7 +127,7 @@ def Deserializer(file, metadata=None, scoped_session=None, engine=None):
return None
else:
type_, args = m.group(1, 2)
- if type_ == 'attribute':
+ if type_ == "attribute":
key, clsarg = args.split(":")
cls = pickle.loads(b64decode(clsarg))
return getattr(cls, key)
@@ -128,13 +135,13 @@ def Deserializer(file, metadata=None, scoped_session=None, engine=None):
cls = pickle.loads(b64decode(args))
return class_mapper(cls)
elif type_ == "mapperprop":
- mapper, keyname = args.split(':')
+ mapper, keyname = args.split(":")
cls = pickle.loads(b64decode(mapper))
return class_mapper(cls).attrs[keyname]
elif type_ == "table":
return metadata.tables[args]
elif type_ == "column":
- table, colname = args.split(':')
+ table, colname = args.split(":")
return metadata.tables[table].c[colname]
elif type_ == "session":
return scoped_session()
@@ -142,6 +149,7 @@ def Deserializer(file, metadata=None, scoped_session=None, engine=None):
return get_engine()
else:
raise Exception("Unknown token: %s" % type_)
+
unpickler.persistent_load = persistent_load
return unpickler
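A usage sketch for the Serializer/Deserializer pair above; the model and in-memory engine are placeholders:

    from sqlalchemy import Column, Integer, String, create_engine
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.ext.serializer import dumps, loads
    from sqlalchemy.orm import scoped_session, sessionmaker

    Base = declarative_base()

    class User(Base):
        __tablename__ = "user"
        id = Column(Integer, primary_key=True)
        name = Column(String)

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)
    Session = scoped_session(sessionmaker(bind=engine))

    # mappers/tables inside the Query are replaced by the symbolic ids above
    serialized = dumps(Session.query(User).filter(User.name == "ed"))
    query = loads(serialized, Base.metadata, Session)
    print(query.all())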
diff --git a/lib/sqlalchemy/inspection.py b/lib/sqlalchemy/inspection.py
index 3a03e2507..7c2ff97c5 100644
--- a/lib/sqlalchemy/inspection.py
+++ b/lib/sqlalchemy/inspection.py
@@ -32,6 +32,7 @@ in a forwards-compatible way.
"""
from . import util, exc
+
_registrars = util.defaultdict(list)
@@ -66,13 +67,11 @@ def inspect(subject, raiseerr=True):
else:
reg = ret = None
- if raiseerr and (
- reg is None or ret is None
- ):
+ if raiseerr and (reg is None or ret is None):
raise exc.NoInspectionAvailable(
"No inspection system is "
- "available for object of type %s" %
- type_)
+ "available for object of type %s" % type_
+ )
return ret
@@ -81,10 +80,11 @@ def _inspects(*types):
for type_ in types:
if type_ in _registrars:
raise AssertionError(
- "Type %s is already "
- "registered" % type_)
+ "Type %s is already " "registered" % type_
+ )
_registrars[type_] = fn_or_cls
return fn_or_cls
+
return decorate
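What the registrar table above powers in practice; Thing is illustrative:

    from sqlalchemy import Column, Integer, inspect
    from sqlalchemy.exc import NoInspectionAvailable
    from sqlalchemy.ext.declarative import declarative_base

    Base = declarative_base()

    class Thing(Base):
        __tablename__ = "thing"
        id = Column(Integer, primary_key=True)

    print(inspect(Thing))    # Mapper, via the registered inspector
    print(inspect(Thing()))  # InstanceState for the instance
    try:
        inspect(object())    # no inspector registered for this type
    except NoInspectionAvailable as err:
        print(err)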
diff --git a/lib/sqlalchemy/interfaces.py b/lib/sqlalchemy/interfaces.py
index 30698ea33..f352f7f26 100644
--- a/lib/sqlalchemy/interfaces.py
+++ b/lib/sqlalchemy/interfaces.py
@@ -80,17 +80,18 @@ class PoolListener(object):
"""
- listener = util.as_interface(listener,
- methods=('connect', 'first_connect',
- 'checkout', 'checkin'))
- if hasattr(listener, 'connect'):
- event.listen(self, 'connect', listener.connect)
- if hasattr(listener, 'first_connect'):
- event.listen(self, 'first_connect', listener.first_connect)
- if hasattr(listener, 'checkout'):
- event.listen(self, 'checkout', listener.checkout)
- if hasattr(listener, 'checkin'):
- event.listen(self, 'checkin', listener.checkin)
+ listener = util.as_interface(
+ listener,
+ methods=("connect", "first_connect", "checkout", "checkin"),
+ )
+ if hasattr(listener, "connect"):
+ event.listen(self, "connect", listener.connect)
+ if hasattr(listener, "first_connect"):
+ event.listen(self, "first_connect", listener.first_connect)
+ if hasattr(listener, "checkout"):
+ event.listen(self, "checkout", listener.checkout)
+ if hasattr(listener, "checkin"):
+ event.listen(self, "checkin", listener.checkin)
def connect(self, dbapi_con, con_record):
"""Called once for each new DB-API connection or Pool's ``creator()``.
@@ -187,27 +188,20 @@ class ConnectionProxy(object):
@classmethod
def _adapt_listener(cls, self, listener):
-
def adapt_execute(conn, clauseelement, multiparams, params):
-
def execute_wrapper(clauseelement, *multiparams, **params):
return clauseelement, multiparams, params
- return listener.execute(conn, execute_wrapper,
- clauseelement, *multiparams,
- **params)
+ return listener.execute(
+ conn, execute_wrapper, clauseelement, *multiparams, **params
+ )
- event.listen(self, 'before_execute', adapt_execute)
+ event.listen(self, "before_execute", adapt_execute)
- def adapt_cursor_execute(conn, cursor, statement,
- parameters, context, executemany):
-
- def execute_wrapper(
- cursor,
- statement,
- parameters,
- context,
- ):
+ def adapt_cursor_execute(
+ conn, cursor, statement, parameters, context, executemany
+ ):
+ def execute_wrapper(cursor, statement, parameters, context):
return statement, parameters
return listener.cursor_execute(
@@ -217,46 +211,56 @@ class ConnectionProxy(object):
parameters,
context,
executemany,
- )
+ )
- event.listen(self, 'before_cursor_execute', adapt_cursor_execute)
+ event.listen(self, "before_cursor_execute", adapt_cursor_execute)
def do_nothing_callback(*arg, **kw):
pass
def adapt_listener(fn):
-
def go(conn, *arg, **kw):
fn(conn, do_nothing_callback, *arg, **kw)
return util.update_wrapper(go, fn)
- event.listen(self, 'begin', adapt_listener(listener.begin))
- event.listen(self, 'rollback',
- adapt_listener(listener.rollback))
- event.listen(self, 'commit', adapt_listener(listener.commit))
- event.listen(self, 'savepoint',
- adapt_listener(listener.savepoint))
- event.listen(self, 'rollback_savepoint',
- adapt_listener(listener.rollback_savepoint))
- event.listen(self, 'release_savepoint',
- adapt_listener(listener.release_savepoint))
- event.listen(self, 'begin_twophase',
- adapt_listener(listener.begin_twophase))
- event.listen(self, 'prepare_twophase',
- adapt_listener(listener.prepare_twophase))
- event.listen(self, 'rollback_twophase',
- adapt_listener(listener.rollback_twophase))
- event.listen(self, 'commit_twophase',
- adapt_listener(listener.commit_twophase))
+ event.listen(self, "begin", adapt_listener(listener.begin))
+ event.listen(self, "rollback", adapt_listener(listener.rollback))
+ event.listen(self, "commit", adapt_listener(listener.commit))
+ event.listen(self, "savepoint", adapt_listener(listener.savepoint))
+ event.listen(
+ self,
+ "rollback_savepoint",
+ adapt_listener(listener.rollback_savepoint),
+ )
+ event.listen(
+ self,
+ "release_savepoint",
+ adapt_listener(listener.release_savepoint),
+ )
+ event.listen(
+ self, "begin_twophase", adapt_listener(listener.begin_twophase)
+ )
+ event.listen(
+ self, "prepare_twophase", adapt_listener(listener.prepare_twophase)
+ )
+ event.listen(
+ self,
+ "rollback_twophase",
+ adapt_listener(listener.rollback_twophase),
+ )
+ event.listen(
+ self, "commit_twophase", adapt_listener(listener.commit_twophase)
+ )
def execute(self, conn, execute, clauseelement, *multiparams, **params):
"""Intercept high level execute() events."""
return execute(clauseelement, *multiparams, **params)
- def cursor_execute(self, execute, cursor, statement, parameters,
- context, executemany):
+ def cursor_execute(
+ self, execute, cursor, statement, parameters, context, executemany
+ ):
"""Intercept low-level cursor execute() events."""
return execute(cursor, statement, parameters, context)
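These deprecated interfaces adapt onto the event hooks, which can also be used directly; a sketch against an in-memory engine:

    from sqlalchemy import create_engine, event

    engine = create_engine("sqlite://")

    @event.listens_for(engine, "before_cursor_execute")
    def log_sql(conn, cursor, statement, parameters, context, executemany):
        # same signature that adapt_cursor_execute() above wraps
        print("SQL:", statement)

    engine.execute("SELECT 1")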
diff --git a/lib/sqlalchemy/log.py b/lib/sqlalchemy/log.py
index a79b21e17..6b0b2e90e 100644
--- a/lib/sqlalchemy/log.py
+++ b/lib/sqlalchemy/log.py
@@ -24,15 +24,16 @@ import sys
# set initial level to WARN. This so that
# log statements don't occur in the absence of explicit
# logging being enabled for 'sqlalchemy'.
-rootlogger = logging.getLogger('sqlalchemy')
+rootlogger = logging.getLogger("sqlalchemy")
if rootlogger.level == logging.NOTSET:
rootlogger.setLevel(logging.WARN)
def _add_default_handler(logger):
handler = logging.StreamHandler(sys.stdout)
- handler.setFormatter(logging.Formatter(
- '%(asctime)s %(levelname)s %(name)s %(message)s'))
+ handler.setFormatter(
+ logging.Formatter("%(asctime)s %(levelname)s %(name)s %(message)s")
+ )
logger.addHandler(handler)
@@ -82,7 +83,7 @@ class InstanceLogger(object):
None: logging.NOTSET,
False: logging.NOTSET,
True: logging.INFO,
- 'debug': logging.DEBUG,
+ "debug": logging.DEBUG,
}
def __init__(self, echo, name):
@@ -91,8 +92,7 @@ class InstanceLogger(object):
# if echo flag is enabled and no handlers,
# add a handler to the list
- if self._echo_map[echo] <= logging.INFO \
- and not self.logger.handlers:
+ if self._echo_map[echo] <= logging.INFO and not self.logger.handlers:
_add_default_handler(self.logger)
#
@@ -174,12 +174,16 @@ def instance_logger(instance, echoflag=None):
"""create a logger for an instance that implements :class:`.Identified`."""
if instance.logging_name:
- name = "%s.%s.%s" % (instance.__class__.__module__,
- instance.__class__.__name__,
- instance.logging_name)
+ name = "%s.%s.%s" % (
+ instance.__class__.__module__,
+ instance.__class__.__name__,
+ instance.logging_name,
+ )
else:
- name = "%s.%s" % (instance.__class__.__module__,
- instance.__class__.__name__)
+ name = "%s.%s" % (
+ instance.__class__.__module__,
+ instance.__class__.__name__,
+ )
instance._echo = echoflag
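Two equivalent ways to get SQL logging through the "sqlalchemy" logger hierarchy configured above; a sketch:

    import logging
    from sqlalchemy import create_engine

    # echo=True routes through InstanceLogger at INFO ("debug" gives DEBUG)
    engine = create_engine("sqlite://", echo=True)

    # or configure the logging hierarchy directly and leave echo unset
    logging.basicConfig()
    logging.getLogger("sqlalchemy.engine").setLevel(logging.INFO)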
diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py
index 1784ea21f..8e7b4cee6 100644
--- a/lib/sqlalchemy/orm/__init__.py
+++ b/lib/sqlalchemy/orm/__init__.py
@@ -20,14 +20,9 @@ from .mapper import (
class_mapper,
configure_mappers,
reconstructor,
- validates
-)
-from .interfaces import (
- EXT_CONTINUE,
- EXT_STOP,
- EXT_SKIP,
- PropComparator,
+ validates,
)
+from .interfaces import EXT_CONTINUE, EXT_STOP, EXT_SKIP, PropComparator
from .deprecated_interfaces import (
MapperExtension,
SessionExtension,
@@ -50,20 +45,15 @@ from .descriptor_props import (
CompositeProperty,
SynonymProperty,
)
-from .relationships import (
- foreign,
- remote,
-)
+from .relationships import foreign, remote
from .session import (
Session,
object_session,
sessionmaker,
make_transient,
- make_transient_to_detached
-)
-from .scoping import (
- scoped_session
+ make_transient_to_detached,
)
+from .scoping import scoped_session
from . import mapper as mapperlib
from .query import AliasOption, Query, Bundle
from ..util.langhelpers import public_factory
@@ -103,11 +93,12 @@ def create_session(bind=None, **kwargs):
create_session().
"""
- kwargs.setdefault('autoflush', False)
- kwargs.setdefault('autocommit', True)
- kwargs.setdefault('expire_on_commit', False)
+ kwargs.setdefault("autoflush", False)
+ kwargs.setdefault("autocommit", True)
+ kwargs.setdefault("expire_on_commit", False)
return Session(bind=bind, **kwargs)
+
relationship = public_factory(RelationshipProperty, ".orm.relationship")
@@ -133,7 +124,7 @@ def dynamic_loader(argument, **kw):
on dynamic loading.
"""
- kw['lazy'] = 'dynamic'
+ kw["lazy"] = "dynamic"
return relationship(argument, **kw)
@@ -193,16 +184,21 @@ def query_expression():
prop.strategy_key = (("query_expression", True),)
return prop
+
mapper = public_factory(Mapper, ".orm.mapper")
synonym = public_factory(SynonymProperty, ".orm.synonym")
-comparable_property = public_factory(ComparableProperty,
- ".orm.comparable_property")
+comparable_property = public_factory(
+ ComparableProperty, ".orm.comparable_property"
+)
-@_sa_util.deprecated("0.7", message=":func:`.compile_mappers` "
- "is renamed to :func:`.configure_mappers`")
+@_sa_util.deprecated(
+ "0.7",
+ message=":func:`.compile_mappers` "
+ "is renamed to :func:`.configure_mappers`",
+)
def compile_mappers():
"""Initialize the inter-mapper relationships of all mappers that have
been defined.
@@ -243,6 +239,7 @@ def clear_mappers():
finally:
mapperlib._CONFIGURE_MUTEX.release()
+
from . import strategy_options
joinedload = strategy_options.joinedload._unbound_fn
@@ -289,10 +286,14 @@ def __go(lcls):
from . import loading
import inspect as _inspect
- __all__ = sorted(name for name, obj in lcls.items()
- if not (name.startswith('_') or _inspect.ismodule(obj)))
+ __all__ = sorted(
+ name
+ for name, obj in lcls.items()
+ if not (name.startswith("_") or _inspect.ismodule(obj))
+ )
_sa_util.dependencies.resolve_all("sqlalchemy.orm")
_sa_util.dependencies.resolve_all("sqlalchemy.ext")
+
__go(locals())
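The create_session() defaults set above, spelled out; the engine is a placeholder:

    from sqlalchemy import create_engine
    from sqlalchemy.orm import create_session

    engine = create_engine("sqlite://")

    # equivalent to Session(bind=engine, autoflush=False,
    # autocommit=True, expire_on_commit=False)
    session = create_session(bind=engine)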
diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py
index b08c46741..1648c9ae1 100644
--- a/lib/sqlalchemy/orm/attributes.py
+++ b/lib/sqlalchemy/orm/attributes.py
@@ -20,19 +20,37 @@ from . import interfaces, collections, exc as orm_exc
from .base import instance_state, instance_dict, manager_of_class
-from .base import PASSIVE_NO_RESULT, ATTR_WAS_SET, ATTR_EMPTY, NO_VALUE,\
- NEVER_SET, NO_CHANGE, CALLABLES_OK, SQL_OK, RELATED_OBJECT_OK,\
- INIT_OK, NON_PERSISTENT_OK, LOAD_AGAINST_COMMITTED, PASSIVE_OFF,\
- PASSIVE_RETURN_NEVER_SET, PASSIVE_NO_INITIALIZE, PASSIVE_NO_FETCH,\
- PASSIVE_NO_FETCH_RELATED, PASSIVE_ONLY_PERSISTENT, NO_AUTOFLUSH, \
- NO_RAISE
+from .base import (
+ PASSIVE_NO_RESULT,
+ ATTR_WAS_SET,
+ ATTR_EMPTY,
+ NO_VALUE,
+ NEVER_SET,
+ NO_CHANGE,
+ CALLABLES_OK,
+ SQL_OK,
+ RELATED_OBJECT_OK,
+ INIT_OK,
+ NON_PERSISTENT_OK,
+ LOAD_AGAINST_COMMITTED,
+ PASSIVE_OFF,
+ PASSIVE_RETURN_NEVER_SET,
+ PASSIVE_NO_INITIALIZE,
+ PASSIVE_NO_FETCH,
+ PASSIVE_NO_FETCH_RELATED,
+ PASSIVE_ONLY_PERSISTENT,
+ NO_AUTOFLUSH,
+ NO_RAISE,
+)
from .base import state_str, instance_str
@inspection._self_inspects
-class QueryableAttribute(interfaces._MappedAttribute,
- interfaces.InspectionAttr,
- interfaces.PropComparator):
+class QueryableAttribute(
+ interfaces._MappedAttribute,
+ interfaces.InspectionAttr,
+ interfaces.PropComparator,
+):
"""Base class for :term:`descriptor` objects that intercept
attribute events on behalf of a :class:`.MapperProperty`
object. The actual :class:`.MapperProperty` is accessible
@@ -53,9 +71,15 @@ class QueryableAttribute(interfaces._MappedAttribute,
is_attribute = True
- def __init__(self, class_, key, impl=None,
- comparator=None, parententity=None,
- of_type=None):
+ def __init__(
+ self,
+ class_,
+ key,
+ impl=None,
+ comparator=None,
+ parententity=None,
+ of_type=None,
+ ):
self.class_ = class_
self.key = key
self.impl = impl
@@ -77,8 +101,9 @@ class QueryableAttribute(interfaces._MappedAttribute,
return self.impl.supports_population
def get_history(self, instance, passive=PASSIVE_OFF):
- return self.impl.get_history(instance_state(instance),
- instance_dict(instance), passive)
+ return self.impl.get_history(
+ instance_state(instance), instance_dict(instance), passive
+ )
def __selectable__(self):
# TODO: conditionally attach this method based on clause_element ?
@@ -159,11 +184,13 @@ class QueryableAttribute(interfaces._MappedAttribute,
def adapt_to_entity(self, adapt_to_entity):
assert not self._of_type
- return self.__class__(adapt_to_entity.entity,
- self.key, impl=self.impl,
- comparator=self.comparator.adapt_to_entity(
- adapt_to_entity),
- parententity=adapt_to_entity)
+ return self.__class__(
+ adapt_to_entity.entity,
+ self.key,
+ impl=self.impl,
+ comparator=self.comparator.adapt_to_entity(adapt_to_entity),
+ parententity=adapt_to_entity,
+ )
def of_type(self, cls):
return QueryableAttribute(
@@ -172,7 +199,8 @@ class QueryableAttribute(interfaces._MappedAttribute,
self.impl,
self.comparator.of_type(cls),
self._parententity,
- of_type=cls)
+ of_type=cls,
+ )
def label(self, name):
return self._query_clause_element().label(name)
@@ -191,12 +219,14 @@ class QueryableAttribute(interfaces._MappedAttribute,
return getattr(self.comparator, key)
except AttributeError:
raise AttributeError(
- 'Neither %r object nor %r object associated with %s '
- 'has an attribute %r' % (
+ "Neither %r object nor %r object associated with %s "
+ "has an attribute %r"
+ % (
type(self).__name__,
type(self.comparator).__name__,
self,
- key)
+ key,
+ )
)
def __str__(self):
@@ -226,8 +256,9 @@ class InstrumentedAttribute(QueryableAttribute):
"""
def __set__(self, instance, value):
- self.impl.set(instance_state(instance),
- instance_dict(instance), value, None)
+ self.impl.set(
+ instance_state(instance), instance_dict(instance), value, None
+ )
def __delete__(self, instance):
self.impl.delete(instance_state(instance), instance_dict(instance))
@@ -260,10 +291,16 @@ def create_proxied_attribute(descriptor):
"""
- def __init__(self, class_, key, descriptor,
- comparator,
- adapt_to_entity=None, doc=None,
- original_property=None):
+ def __init__(
+ self,
+ class_,
+ key,
+ descriptor,
+ comparator,
+ adapt_to_entity=None,
+ doc=None,
+ original_property=None,
+ ):
self.class_ = class_
self.key = key
self.descriptor = descriptor
@@ -284,15 +321,18 @@ def create_proxied_attribute(descriptor):
self._comparator = self._comparator()
if self._adapt_to_entity:
self._comparator = self._comparator.adapt_to_entity(
- self._adapt_to_entity)
+ self._adapt_to_entity
+ )
return self._comparator
def adapt_to_entity(self, adapt_to_entity):
- return self.__class__(adapt_to_entity.entity,
- self.key,
- self.descriptor,
- self._comparator,
- adapt_to_entity)
+ return self.__class__(
+ adapt_to_entity.entity,
+ self.key,
+ self.descriptor,
+ self._comparator,
+ adapt_to_entity,
+ )
def __get__(self, instance, owner):
if instance is None:
@@ -314,21 +354,24 @@ def create_proxied_attribute(descriptor):
return getattr(self.comparator, attribute)
except AttributeError:
raise AttributeError(
- 'Neither %r object nor %r object associated with %s '
- 'has an attribute %r' % (
+ "Neither %r object nor %r object associated with %s "
+ "has an attribute %r"
+ % (
type(descriptor).__name__,
type(self.comparator).__name__,
self,
- attribute)
+ attribute,
+ )
)
- Proxy.__name__ = type(descriptor).__name__ + 'Proxy'
+ Proxy.__name__ = type(descriptor).__name__ + "Proxy"
- util.monkeypatch_proxied_specials(Proxy, type(descriptor),
- name='descriptor',
- from_instance=descriptor)
+ util.monkeypatch_proxied_specials(
+ Proxy, type(descriptor), name="descriptor", from_instance=descriptor
+ )
return Proxy
+
OP_REMOVE = util.symbol("REMOVE")
OP_APPEND = util.symbol("APPEND")
OP_REPLACE = util.symbol("REPLACE")
@@ -364,7 +407,7 @@ class Event(object):
"""
- __slots__ = 'impl', 'op', 'parent_token'
+ __slots__ = "impl", "op", "parent_token"
def __init__(self, attribute_impl, op):
self.impl = attribute_impl
@@ -372,9 +415,11 @@ class Event(object):
self.parent_token = self.impl.parent_token
def __eq__(self, other):
- return isinstance(other, Event) and \
- other.impl is self.impl and \
- other.op == self.op
+ return (
+ isinstance(other, Event)
+ and other.impl is self.impl
+ and other.op == self.op
+ )
@property
def key(self):
@@ -387,12 +432,22 @@ class Event(object):
class AttributeImpl(object):
"""internal implementation for instrumented attributes."""
- def __init__(self, class_, key,
- callable_, dispatch, trackparent=False, extension=None,
- compare_function=None, active_history=False,
- parent_token=None, expire_missing=True,
- send_modified_events=True, accepts_scalar_loader=None,
- **kwargs):
+ def __init__(
+ self,
+ class_,
+ key,
+ callable_,
+ dispatch,
+ trackparent=False,
+ extension=None,
+ compare_function=None,
+ active_history=False,
+ parent_token=None,
+ expire_missing=True,
+ send_modified_events=True,
+ accepts_scalar_loader=None,
+ **kwargs
+ ):
r"""Construct an AttributeImpl.
\class_
@@ -471,9 +526,17 @@ class AttributeImpl(object):
self._modified_token = Event(self, OP_MODIFIED)
__slots__ = (
- 'class_', 'key', 'callable_', 'dispatch', 'trackparent',
- 'parent_token', 'send_modified_events', 'is_equal', 'expire_missing',
- '_modified_token', 'accepts_scalar_loader'
+ "class_",
+ "key",
+ "callable_",
+ "dispatch",
+ "trackparent",
+ "parent_token",
+ "send_modified_events",
+ "is_equal",
+ "expire_missing",
+ "_modified_token",
+ "accepts_scalar_loader",
)
def __str__(self):
@@ -508,8 +571,9 @@ class AttributeImpl(object):
msg = "This AttributeImpl is not configured to track parents."
assert self.trackparent, msg
- return state.parents.get(id(self.parent_token), optimistic) \
- is not False
+ return (
+ state.parents.get(id(self.parent_token), optimistic) is not False
+ )
def sethasparent(self, state, parent_state, value):
"""Set a boolean flag on the given item corresponding to
@@ -527,8 +591,10 @@ class AttributeImpl(object):
if id_ in state.parents:
last_parent = state.parents[id_]
- if last_parent is not False and \
- last_parent.key != parent_state.key:
+ if (
+ last_parent is not False
+ and last_parent.key != parent_state.key
+ ):
if last_parent.obj() is None:
raise orm_exc.StaleDataError(
@@ -536,10 +602,13 @@ class AttributeImpl(object):
"state %s along attribute '%s', "
"but the parent record "
"has gone stale, can't be sure this "
- "is the most recent parent." %
- (state_str(state),
- state_str(parent_state),
- self.key))
+ "is the most recent parent."
+ % (
+ state_str(state),
+ state_str(parent_state),
+ self.key,
+ )
+ )
return
@@ -588,8 +657,10 @@ class AttributeImpl(object):
else:
# if history present, don't load
key = self.key
- if key not in state.committed_state or \
- state.committed_state[key] is NEVER_SET:
+ if (
+ key not in state.committed_state
+ or state.committed_state[key] is NEVER_SET
+ ):
if not passive & CALLABLES_OK:
return PASSIVE_NO_RESULT
@@ -613,7 +684,8 @@ class AttributeImpl(object):
raise KeyError(
"Deferred loader for attribute "
"%r failed to populate "
- "correctly" % key)
+ "correctly" % key
+ )
elif value is not ATTR_EMPTY:
return self.set_committed_value(state, dict_, value)
@@ -627,15 +699,31 @@ class AttributeImpl(object):
self.set(state, dict_, value, initiator, passive=passive)
def remove(self, state, dict_, value, initiator, passive=PASSIVE_OFF):
- self.set(state, dict_, None, initiator,
- passive=passive, check_old=value)
+ self.set(
+ state, dict_, None, initiator, passive=passive, check_old=value
+ )
def pop(self, state, dict_, value, initiator, passive=PASSIVE_OFF):
- self.set(state, dict_, None, initiator,
- passive=passive, check_old=value, pop=True)
+ self.set(
+ state,
+ dict_,
+ None,
+ initiator,
+ passive=passive,
+ check_old=value,
+ pop=True,
+ )
- def set(self, state, dict_, value, initiator,
- passive=PASSIVE_OFF, check_old=None, pop=False):
+ def set(
+ self,
+ state,
+ dict_,
+ value,
+ initiator,
+ passive=PASSIVE_OFF,
+ check_old=None,
+ pop=False,
+ ):
raise NotImplementedError()
def get_committed_value(self, state, dict_, passive=PASSIVE_OFF):
@@ -667,7 +755,7 @@ class ScalarAttributeImpl(AttributeImpl):
collection = False
dynamic = False
- __slots__ = '_replace_token', '_append_token', '_remove_token'
+ __slots__ = "_replace_token", "_append_token", "_remove_token"
def __init__(self, *arg, **kw):
super(ScalarAttributeImpl, self).__init__(*arg, **kw)
@@ -685,10 +773,13 @@ class ScalarAttributeImpl(AttributeImpl):
state._modified_event(dict_, self, old)
existing = dict_.pop(self.key, NO_VALUE)
- if existing is NO_VALUE and old is NO_VALUE and \
- not state.expired and \
- self.key not in state.expired_attributes:
- raise AttributeError("%s object does not have a value" % self)
+ if (
+ existing is NO_VALUE
+ and old is NO_VALUE
+ and not state.expired
+ and self.key not in state.expired_attributes
+ ):
+ raise AttributeError("%s object does not have a value" % self)
def get_history(self, state, dict_, passive=PASSIVE_OFF):
if self.key in dict_:
@@ -702,23 +793,33 @@ class ScalarAttributeImpl(AttributeImpl):
else:
return History.from_scalar_attribute(self, state, current)
- def set(self, state, dict_, value, initiator,
- passive=PASSIVE_OFF, check_old=None, pop=False):
+ def set(
+ self,
+ state,
+ dict_,
+ value,
+ initiator,
+ passive=PASSIVE_OFF,
+ check_old=None,
+ pop=False,
+ ):
if self.dispatch._active_history:
old = self.get(state, dict_, PASSIVE_RETURN_NEVER_SET)
else:
old = dict_.get(self.key, NO_VALUE)
if self.dispatch.set:
- value = self.fire_replace_event(state, dict_,
- value, old, initiator)
+ value = self.fire_replace_event(
+ state, dict_, value, old, initiator
+ )
state._modified_event(dict_, self, old)
dict_[self.key] = value
def fire_replace_event(self, state, dict_, value, previous, initiator):
for fn in self.dispatch.set:
value = fn(
- state, value, previous, initiator or self._replace_token)
+ state, value, previous, initiator or self._replace_token
+ )
return value
def fire_remove_event(self, state, dict_, value, initiator):
@@ -748,13 +849,20 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
def delete(self, state, dict_):
if self.dispatch._active_history:
old = self.get(
- state, dict_,
- passive=PASSIVE_ONLY_PERSISTENT |
- NO_AUTOFLUSH | LOAD_AGAINST_COMMITTED)
+ state,
+ dict_,
+ passive=PASSIVE_ONLY_PERSISTENT
+ | NO_AUTOFLUSH
+ | LOAD_AGAINST_COMMITTED,
+ )
else:
old = self.get(
- state, dict_, passive=PASSIVE_NO_FETCH ^ INIT_OK |
- LOAD_AGAINST_COMMITTED | NO_RAISE)
+ state,
+ dict_,
+ passive=PASSIVE_NO_FETCH ^ INIT_OK
+ | LOAD_AGAINST_COMMITTED
+ | NO_RAISE,
+ )
self.fire_remove_event(state, dict_, old, self._remove_token)
@@ -763,8 +871,11 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
# if the attribute is expired, we currently have no way to tell
# that an object-attribute was expired vs. not loaded. So
# for this test, we look to see if the object has a DB identity.
- if existing is NO_VALUE and old is not PASSIVE_NO_RESULT and \
- state.key is None:
+ if (
+ existing is NO_VALUE
+ and old is not PASSIVE_NO_RESULT
+ and state.key is None
+ ):
raise AttributeError("%s object does not have a value" % self)
def get_history(self, state, dict_, passive=PASSIVE_OFF):
@@ -788,50 +899,69 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
return []
# can't use __hash__(), can't use __eq__() here
- if current is not None and \
- current is not PASSIVE_NO_RESULT and \
- current is not NEVER_SET:
+ if (
+ current is not None
+ and current is not PASSIVE_NO_RESULT
+ and current is not NEVER_SET
+ ):
ret = [(instance_state(current), current)]
else:
ret = [(None, None)]
if self.key in state.committed_state:
original = state.committed_state[self.key]
- if original is not None and \
- original is not PASSIVE_NO_RESULT and \
- original is not NEVER_SET and \
- original is not current:
+ if (
+ original is not None
+ and original is not PASSIVE_NO_RESULT
+ and original is not NEVER_SET
+ and original is not current
+ ):
ret.append((instance_state(original), original))
return ret
- def set(self, state, dict_, value, initiator,
- passive=PASSIVE_OFF, check_old=None, pop=False):
+ def set(
+ self,
+ state,
+ dict_,
+ value,
+ initiator,
+ passive=PASSIVE_OFF,
+ check_old=None,
+ pop=False,
+ ):
"""Set a value on the given InstanceState.
"""
if self.dispatch._active_history:
old = self.get(
- state, dict_,
- passive=PASSIVE_ONLY_PERSISTENT |
- NO_AUTOFLUSH | LOAD_AGAINST_COMMITTED)
+ state,
+ dict_,
+ passive=PASSIVE_ONLY_PERSISTENT
+ | NO_AUTOFLUSH
+ | LOAD_AGAINST_COMMITTED,
+ )
else:
old = self.get(
- state, dict_, passive=PASSIVE_NO_FETCH ^ INIT_OK |
- LOAD_AGAINST_COMMITTED | NO_RAISE)
+ state,
+ dict_,
+ passive=PASSIVE_NO_FETCH ^ INIT_OK
+ | LOAD_AGAINST_COMMITTED
+ | NO_RAISE,
+ )
- if check_old is not None and \
- old is not PASSIVE_NO_RESULT and \
- check_old is not old:
+ if (
+ check_old is not None
+ and old is not PASSIVE_NO_RESULT
+ and check_old is not old
+ ):
if pop:
return
else:
raise ValueError(
- "Object %s not associated with %s on attribute '%s'" % (
- instance_str(check_old),
- state_str(state),
- self.key
- ))
+ "Object %s not associated with %s on attribute '%s'"
+ % (instance_str(check_old), state_str(state), self.key)
+ )
value = self.fire_replace_event(state, dict_, value, old, initiator)
dict_[self.key] = value
@@ -847,13 +977,17 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
def fire_replace_event(self, state, dict_, value, previous, initiator):
if self.trackparent:
- if (previous is not value and
- previous not in (None, PASSIVE_NO_RESULT, NEVER_SET)):
+ if previous is not value and previous not in (
+ None,
+ PASSIVE_NO_RESULT,
+ NEVER_SET,
+ ):
self.sethasparent(instance_state(previous), state, False)
for fn in self.dispatch.set:
value = fn(
- state, value, previous, initiator or self._replace_token)
+ state, value, previous, initiator or self._replace_token
+ )
state._modified_event(dict_, self, previous)
@@ -875,6 +1009,7 @@ class CollectionAttributeImpl(AttributeImpl):
semantics to the orm layer independent of the user data implementation.
"""
+
default_accepts_scalar_loader = False
uses_objects = True
supports_population = True
@@ -882,21 +1017,37 @@ class CollectionAttributeImpl(AttributeImpl):
dynamic = False
__slots__ = (
- 'copy', 'collection_factory', '_append_token', '_remove_token',
- '_bulk_replace_token', '_duck_typed_as'
+ "copy",
+ "collection_factory",
+ "_append_token",
+ "_remove_token",
+ "_bulk_replace_token",
+ "_duck_typed_as",
)
- def __init__(self, class_, key, callable_, dispatch,
- typecallable=None, trackparent=False, extension=None,
- copy_function=None, compare_function=None, **kwargs):
+ def __init__(
+ self,
+ class_,
+ key,
+ callable_,
+ dispatch,
+ typecallable=None,
+ trackparent=False,
+ extension=None,
+ copy_function=None,
+ compare_function=None,
+ **kwargs
+ ):
super(CollectionAttributeImpl, self).__init__(
class_,
key,
- callable_, dispatch,
+ callable_,
+ dispatch,
trackparent=trackparent,
extension=extension,
compare_function=compare_function,
- **kwargs)
+ **kwargs
+ )
if copy_function is None:
copy_function = self.__copy
@@ -906,7 +1057,8 @@ class CollectionAttributeImpl(AttributeImpl):
self._remove_token = Event(self, OP_REMOVE)
self._bulk_replace_token = Event(self, OP_BULK_REPLACE)
self._duck_typed_as = util.duck_type_collection(
- self.collection_factory())
+ self.collection_factory()
+ )
if getattr(self.collection_factory, "_sa_linker", None):
@@ -935,35 +1087,42 @@ class CollectionAttributeImpl(AttributeImpl):
return []
current = dict_[self.key]
- current = getattr(current, '_sa_adapter')
+ current = getattr(current, "_sa_adapter")
if self.key in state.committed_state:
original = state.committed_state[self.key]
if original not in (NO_VALUE, NEVER_SET):
- current_states = [((c is not None) and
- instance_state(c) or None, c)
- for c in current]
- original_states = [((c is not None) and
- instance_state(c) or None, c)
- for c in original]
+ current_states = [
+ ((c is not None) and instance_state(c) or None, c)
+ for c in current
+ ]
+ original_states = [
+ ((c is not None) and instance_state(c) or None, c)
+ for c in original
+ ]
current_set = dict(current_states)
original_set = dict(original_states)
- return \
- [(s, o) for s, o in current_states
- if s not in original_set] + \
- [(s, o) for s, o in current_states
- if s in original_set] + \
- [(s, o) for s, o in original_states
- if s not in current_set]
+ return (
+ [
+ (s, o)
+ for s, o in current_states
+ if s not in original_set
+ ]
+ + [(s, o) for s, o in current_states if s in original_set]
+ + [
+ (s, o)
+ for s, o in original_states
+ if s not in current_set
+ ]
+ )
return [(instance_state(o), o) for o in current]
def fire_append_event(self, state, dict_, value, initiator):
for fn in self.dispatch.append:
- value = fn(
- state, value, initiator or self._append_token)
+ value = fn(state, value, initiator or self._append_token)
state._modified_event(dict_, self, NEVER_SET, True)
@@ -1015,7 +1174,8 @@ class CollectionAttributeImpl(AttributeImpl):
def _initialize_collection(self, state):
adapter, collection = state.manager.initialize_collection(
- self.key, state, self.collection_factory)
+ self.key, state, self.collection_factory
+ )
self.dispatch.init_collection(state, collection, adapter)
@@ -1025,8 +1185,9 @@ class CollectionAttributeImpl(AttributeImpl):
collection = self.get_collection(state, dict_, passive=passive)
if collection is PASSIVE_NO_RESULT:
value = self.fire_append_event(state, dict_, value, initiator)
- assert self.key not in dict_, \
- "Collection was loaded during event handling."
+ assert (
+ self.key not in dict_
+ ), "Collection was loaded during event handling."
state._get_pending_mutation(self.key).append(value)
else:
collection.append_with_event(value, initiator)
@@ -1035,8 +1196,9 @@ class CollectionAttributeImpl(AttributeImpl):
collection = self.get_collection(state, state.dict, passive=passive)
if collection is PASSIVE_NO_RESULT:
self.fire_remove_event(state, dict_, value, initiator)
- assert self.key not in dict_, \
- "Collection was loaded during event handling."
+ assert (
+ self.key not in dict_
+ ), "Collection was loaded during event handling."
state._get_pending_mutation(self.key).remove(value)
else:
collection.remove_with_event(value, initiator)
@@ -1050,8 +1212,16 @@ class CollectionAttributeImpl(AttributeImpl):
except (ValueError, KeyError, IndexError):
pass
- def set(self, state, dict_, value, initiator=None,
- passive=PASSIVE_OFF, pop=False, _adapt=True):
+ def set(
+ self,
+ state,
+ dict_,
+ value,
+ initiator=None,
+ passive=PASSIVE_OFF,
+ pop=False,
+ _adapt=True,
+ ):
iterable = orig_iterable = value
# pulling a new collection first so that an adaptation exception does
@@ -1065,23 +1235,28 @@ class CollectionAttributeImpl(AttributeImpl):
receiving_type = self._duck_typed_as
if setting_type is not receiving_type:
- given = iterable is None and 'None' or \
- iterable.__class__.__name__
+ given = (
+ iterable is None
+ and "None"
+ or iterable.__class__.__name__
+ )
wanted = self._duck_typed_as.__name__
raise TypeError(
- "Incompatible collection type: %s is not %s-like" % (
- given, wanted))
+ "Incompatible collection type: %s is not %s-like"
+ % (given, wanted)
+ )
# If the object is an adapted collection, return the (iterable)
# adapter.
- if hasattr(iterable, '_sa_iterator'):
+ if hasattr(iterable, "_sa_iterator"):
iterable = iterable._sa_iterator()
elif setting_type is dict:
if util.py3k:
iterable = iterable.values()
else:
iterable = getattr(
- iterable, 'itervalues', iterable.values)()
+ iterable, "itervalues", iterable.values
+ )()
else:
iterable = iter(iterable)
new_values = list(iterable)
@@ -1106,14 +1281,14 @@ class CollectionAttributeImpl(AttributeImpl):
dict_[self.key] = user_data
collections.bulk_replace(
- new_values, old_collection, new_collection,
- initiator=evt)
+ new_values, old_collection, new_collection, initiator=evt
+ )
del old._sa_adapter
self.dispatch.dispose_collection(state, old, old_collection)
def _invalidate_collection(self, collection):
- adapter = getattr(collection, '_sa_adapter')
+ adapter = getattr(collection, "_sa_adapter")
adapter.invalidated = True
def set_committed_value(self, state, dict_, value):
@@ -1143,8 +1318,9 @@ class CollectionAttributeImpl(AttributeImpl):
return user_data
- def get_collection(self, state, dict_,
- user_data=None, passive=PASSIVE_OFF):
+ def get_collection(
+ self, state, dict_, user_data=None, passive=PASSIVE_OFF
+ ):
"""Retrieve the CollectionAdapter associated with the given state.
Creates a new CollectionAdapter if one does not exist.
@@ -1155,7 +1331,7 @@ class CollectionAttributeImpl(AttributeImpl):
if user_data is PASSIVE_NO_RESULT:
return user_data
- return getattr(user_data, '_sa_adapter')
+ return getattr(user_data, "_sa_adapter")
def backref_listeners(attribute, key, uselist):
@@ -1177,24 +1353,29 @@ def backref_listeners(attribute, key, uselist):
"Bidirectional attribute conflict detected: "
'Passing object %s to attribute "%s" '
'triggers a modify event on attribute "%s" '
- 'via the backref "%s".' % (
+ 'via the backref "%s".'
+ % (
state_str(child_state),
initiator.parent_token,
child_impl.parent_token,
- attribute.impl.parent_token
+ attribute.impl.parent_token,
)
)
def emit_backref_from_scalar_set_event(state, child, oldchild, initiator):
if oldchild is child:
return child
- if oldchild is not None and \
- oldchild is not PASSIVE_NO_RESULT and \
- oldchild is not NEVER_SET:
+ if (
+ oldchild is not None
+ and oldchild is not PASSIVE_NO_RESULT
+ and oldchild is not NEVER_SET
+ ):
# With lazy=None, there's no guarantee that the full collection is
# present when updating via a backref.
- old_state, old_dict = instance_state(oldchild),\
- instance_dict(oldchild)
+ old_state, old_dict = (
+ instance_state(oldchild),
+ instance_dict(oldchild),
+ )
impl = old_state.manager[key].impl
# tokens to test for a recursive loop.
@@ -1204,69 +1385,90 @@ def backref_listeners(attribute, key, uselist):
check_recursive_token = impl._remove_token
if initiator is not check_recursive_token:
- impl.pop(old_state,
- old_dict,
- state.obj(),
- parent_impl._append_token,
- passive=PASSIVE_NO_FETCH)
+ impl.pop(
+ old_state,
+ old_dict,
+ state.obj(),
+ parent_impl._append_token,
+ passive=PASSIVE_NO_FETCH,
+ )
if child is not None:
- child_state, child_dict = instance_state(child),\
- instance_dict(child)
+ child_state, child_dict = (
+ instance_state(child),
+ instance_dict(child),
+ )
child_impl = child_state.manager[key].impl
- if initiator.parent_token is not parent_token and \
- initiator.parent_token is not child_impl.parent_token:
+ if (
+ initiator.parent_token is not parent_token
+ and initiator.parent_token is not child_impl.parent_token
+ ):
_acceptable_key_err(state, initiator, child_impl)
# tokens to test for a recursive loop.
check_append_token = child_impl._append_token
- check_bulk_replace_token = child_impl._bulk_replace_token \
- if child_impl.collection else None
+ check_bulk_replace_token = (
+ child_impl._bulk_replace_token
+ if child_impl.collection
+ else None
+ )
- if initiator is not check_append_token and \
- initiator is not check_bulk_replace_token:
+ if (
+ initiator is not check_append_token
+ and initiator is not check_bulk_replace_token
+ ):
child_impl.append(
child_state,
child_dict,
state.obj(),
initiator,
- passive=PASSIVE_NO_FETCH)
+ passive=PASSIVE_NO_FETCH,
+ )
return child
def emit_backref_from_collection_append_event(state, child, initiator):
if child is None:
return
- child_state, child_dict = instance_state(child), \
- instance_dict(child)
+ child_state, child_dict = instance_state(child), instance_dict(child)
child_impl = child_state.manager[key].impl
- if initiator.parent_token is not parent_token and \
- initiator.parent_token is not child_impl.parent_token:
+ if (
+ initiator.parent_token is not parent_token
+ and initiator.parent_token is not child_impl.parent_token
+ ):
_acceptable_key_err(state, initiator, child_impl)
# tokens to test for a recursive loop.
check_append_token = child_impl._append_token
- check_bulk_replace_token = child_impl._bulk_replace_token \
- if child_impl.collection else None
+ check_bulk_replace_token = (
+ child_impl._bulk_replace_token if child_impl.collection else None
+ )
- if initiator is not check_append_token and \
- initiator is not check_bulk_replace_token:
+ if (
+ initiator is not check_append_token
+ and initiator is not check_bulk_replace_token
+ ):
child_impl.append(
child_state,
child_dict,
state.obj(),
initiator,
- passive=PASSIVE_NO_FETCH)
+ passive=PASSIVE_NO_FETCH,
+ )
return child
def emit_backref_from_collection_remove_event(state, child, initiator):
- if child is not None and \
- child is not PASSIVE_NO_RESULT and \
- child is not NEVER_SET:
- child_state, child_dict = instance_state(child),\
- instance_dict(child)
+ if (
+ child is not None
+ and child is not PASSIVE_NO_RESULT
+ and child is not NEVER_SET
+ ):
+ child_state, child_dict = (
+ instance_state(child),
+ instance_dict(child),
+ )
child_impl = child_state.manager[key].impl
# tokens to test for a recursive loop.
@@ -1276,47 +1478,64 @@ def backref_listeners(attribute, key, uselist):
check_for_dupes_on_remove = uselist and not parent_impl.dynamic
else:
check_remove_token = child_impl._remove_token
- check_replace_token = child_impl._bulk_replace_token \
- if child_impl.collection else None
+ check_replace_token = (
+ child_impl._bulk_replace_token
+ if child_impl.collection
+ else None
+ )
check_for_dupes_on_remove = False
- if initiator is not check_remove_token and \
- initiator is not check_replace_token:
-
- if not check_for_dupes_on_remove or \
- not util.has_dupes(
- # when this event is called, the item is usually
- # present in the list, except for a pop() operation.
- state.dict[parent_impl.key], child):
+ if (
+ initiator is not check_remove_token
+ and initiator is not check_replace_token
+ ):
+
+ if not check_for_dupes_on_remove or not util.has_dupes(
+ # when this event is called, the item is usually
+ # present in the list, except for a pop() operation.
+ state.dict[parent_impl.key],
+ child,
+ ):
child_impl.pop(
child_state,
child_dict,
state.obj(),
initiator,
- passive=PASSIVE_NO_FETCH)
+ passive=PASSIVE_NO_FETCH,
+ )
if uselist:
- event.listen(attribute, "append",
- emit_backref_from_collection_append_event,
- retval=True, raw=True)
+ event.listen(
+ attribute,
+ "append",
+ emit_backref_from_collection_append_event,
+ retval=True,
+ raw=True,
+ )
else:
- event.listen(attribute, "set",
- emit_backref_from_scalar_set_event,
- retval=True, raw=True)
+ event.listen(
+ attribute,
+ "set",
+ emit_backref_from_scalar_set_event,
+ retval=True,
+ raw=True,
+ )
# TODO: need coverage in test/orm/ of remove event
- event.listen(attribute, "remove",
- emit_backref_from_collection_remove_event,
- retval=True, raw=True)
+ event.listen(
+ attribute,
+ "remove",
+ emit_backref_from_collection_remove_event,
+ retval=True,
+ raw=True,
+ )
-_NO_HISTORY = util.symbol('NO_HISTORY')
-_NO_STATE_SYMBOLS = frozenset([
- id(PASSIVE_NO_RESULT),
- id(NO_VALUE),
- id(NEVER_SET)])
-History = util.namedtuple("History", [
- "added", "unchanged", "deleted"
-])
+_NO_HISTORY = util.symbol("NO_HISTORY")
+_NO_STATE_SYMBOLS = frozenset(
+ [id(PASSIVE_NO_RESULT), id(NO_VALUE), id(NEVER_SET)]
+)
+
+History = util.namedtuple("History", ["added", "unchanged", "deleted"])
class History(History):
@@ -1346,6 +1565,7 @@ class History(History):
def __bool__(self):
return self != HISTORY_BLANK
+
__nonzero__ = __bool__
def empty(self):
@@ -1354,29 +1574,24 @@ class History(History):
"""
- return not bool(
- (self.added or self.deleted)
- or self.unchanged
- )
+ return not bool((self.added or self.deleted) or self.unchanged)
def sum(self):
"""Return a collection of added + unchanged + deleted."""
- return (self.added or []) +\
- (self.unchanged or []) +\
- (self.deleted or [])
+ return (
+ (self.added or []) + (self.unchanged or []) + (self.deleted or [])
+ )
def non_deleted(self):
"""Return a collection of added + unchanged."""
- return (self.added or []) +\
- (self.unchanged or [])
+ return (self.added or []) + (self.unchanged or [])
def non_added(self):
"""Return a collection of unchanged + deleted."""
- return (self.unchanged or []) +\
- (self.deleted or [])
+ return (self.unchanged or []) + (self.deleted or [])
def has_changes(self):
"""Return True if this :class:`.History` has changes."""
@@ -1385,15 +1600,18 @@ class History(History):
def as_state(self):
return History(
- [(c is not None)
- and instance_state(c) or None
- for c in self.added],
- [(c is not None)
- and instance_state(c) or None
- for c in self.unchanged],
- [(c is not None)
- and instance_state(c) or None
- for c in self.deleted],
+ [
+ (c is not None) and instance_state(c) or None
+ for c in self.added
+ ],
+ [
+ (c is not None) and instance_state(c) or None
+ for c in self.unchanged
+ ],
+ [
+ (c is not None) and instance_state(c) or None
+ for c in self.deleted
+ ],
)
@classmethod
@@ -1464,21 +1682,21 @@ class History(History):
if current is NO_VALUE or current is NEVER_SET:
return cls((), (), ())
- current = getattr(current, '_sa_adapter')
+ current = getattr(current, "_sa_adapter")
if original in (NO_VALUE, NEVER_SET):
return cls(list(current), (), ())
elif original is _NO_HISTORY:
return cls((), list(current), ())
else:
- current_states = [((c is not None) and instance_state(c)
- or None, c)
- for c in current
- ]
- original_states = [((c is not None) and instance_state(c)
- or None, c)
- for c in original
- ]
+ current_states = [
+ ((c is not None) and instance_state(c) or None, c)
+ for c in current
+ ]
+ original_states = [
+ ((c is not None) and instance_state(c) or None, c)
+ for c in original
+ ]
current_set = dict(current_states)
original_set = dict(original_states)
@@ -1486,9 +1704,10 @@ class History(History):
return cls(
[o for s, o in current_states if s not in original_set],
[o for s, o in current_states if s in original_set],
- [o for s, o in original_states if s not in current_set]
+ [o for s, o in original_states if s not in current_set],
)
+
HISTORY_BLANK = History(None, None, None)
@@ -1509,12 +1728,16 @@ def get_history(obj, key, passive=PASSIVE_OFF):
"""
if passive is True:
- util.warn_deprecated("Passing True for 'passive' is deprecated. "
- "Use attributes.PASSIVE_NO_INITIALIZE")
+ util.warn_deprecated(
+ "Passing True for 'passive' is deprecated. "
+ "Use attributes.PASSIVE_NO_INITIALIZE"
+ )
passive = PASSIVE_NO_INITIALIZE
elif passive is False:
- util.warn_deprecated("Passing False for 'passive' is "
- "deprecated. Use attributes.PASSIVE_OFF")
+ util.warn_deprecated(
+ "Passing False for 'passive' is "
+ "deprecated. Use attributes.PASSIVE_OFF"
+ )
passive = PASSIVE_OFF
return get_state_history(instance_state(obj), key, passive)
@@ -1532,38 +1755,46 @@ def has_parent(cls, obj, key, optimistic=False):
def register_attribute(class_, key, **kw):
- comparator = kw.pop('comparator', None)
- parententity = kw.pop('parententity', None)
- doc = kw.pop('doc', None)
- desc = register_descriptor(class_, key,
- comparator, parententity, doc=doc)
+ comparator = kw.pop("comparator", None)
+ parententity = kw.pop("parententity", None)
+ doc = kw.pop("doc", None)
+ desc = register_descriptor(class_, key, comparator, parententity, doc=doc)
register_attribute_impl(class_, key, **kw)
return desc
-def register_attribute_impl(class_, key,
- uselist=False, callable_=None,
- useobject=False,
- impl_class=None, backref=None, **kw):
+def register_attribute_impl(
+ class_,
+ key,
+ uselist=False,
+ callable_=None,
+ useobject=False,
+ impl_class=None,
+ backref=None,
+ **kw
+):
manager = manager_of_class(class_)
if uselist:
- factory = kw.pop('typecallable', None)
+ factory = kw.pop("typecallable", None)
typecallable = manager.instrument_collection_class(
- key, factory or list)
+ key, factory or list
+ )
else:
- typecallable = kw.pop('typecallable', None)
+ typecallable = kw.pop("typecallable", None)
dispatch = manager[key].dispatch
if impl_class:
impl = impl_class(class_, key, typecallable, dispatch, **kw)
elif uselist:
- impl = CollectionAttributeImpl(class_, key, callable_, dispatch,
- typecallable=typecallable, **kw)
+ impl = CollectionAttributeImpl(
+ class_, key, callable_, dispatch, typecallable=typecallable, **kw
+ )
elif useobject:
- impl = ScalarObjectAttributeImpl(class_, key, callable_,
- dispatch, **kw)
+ impl = ScalarObjectAttributeImpl(
+ class_, key, callable_, dispatch, **kw
+ )
else:
impl = ScalarAttributeImpl(class_, key, callable_, dispatch, **kw)
@@ -1576,12 +1807,14 @@ def register_attribute_impl(class_, key,
return manager[key]
-def register_descriptor(class_, key, comparator=None,
- parententity=None, doc=None):
+def register_descriptor(
+ class_, key, comparator=None, parententity=None, doc=None
+):
manager = manager_of_class(class_)
- descriptor = InstrumentedAttribute(class_, key, comparator=comparator,
- parententity=parententity)
+ descriptor = InstrumentedAttribute(
+ class_, key, comparator=comparator, parententity=parententity
+ )
descriptor.__doc__ = doc
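
The attributes.py hunks above are purely mechanical restyling (line wrapping, double quotes, trailing commas); the history-tracking behavior they touch is unchanged. As a quick orientation, here is a minimal sketch — not part of the diff — of the History tuple that get_history() returns, assuming a SQLAlchemy of roughly this vintage and a throwaway User model:

    from sqlalchemy import Column, Integer, String
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm.attributes import get_history

    Base = declarative_base()


    class User(Base):
        __tablename__ = "user"

        id = Column(Integer, primary_key=True)
        name = Column(String)


    u = User(name="ed")
    u.name = "jack"

    # History is the (added, unchanged, deleted) namedtuple reformatted
    # above; nothing has been flushed yet, so the pending value shows up
    # under "added" and nothing under "deleted".
    hist = get_history(u, "name")
    print(hist.added, hist.unchanged, hist.deleted)

Once the object has been flushed and then mutated, the committed value would appear under "deleted" instead, which is the bookkeeping ScalarAttributeImpl.set() and _modified_event() perform above.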
diff --git a/lib/sqlalchemy/orm/base.py b/lib/sqlalchemy/orm/base.py
index deddaa5a4..abc572d9a 100644
--- a/lib/sqlalchemy/orm/base.py
+++ b/lib/sqlalchemy/orm/base.py
@@ -15,66 +15,69 @@ from . import exc
import operator
PASSIVE_NO_RESULT = util.symbol(
- 'PASSIVE_NO_RESULT',
+ "PASSIVE_NO_RESULT",
"""Symbol returned by a loader callable or other attribute/history
retrieval operation when a value could not be determined, based
on loader callable flags.
- """
+ """,
)
ATTR_WAS_SET = util.symbol(
- 'ATTR_WAS_SET',
+ "ATTR_WAS_SET",
"""Symbol returned by a loader callable to indicate the
retrieved value, or values, were assigned to their attributes
on the target object.
- """
+ """,
)
ATTR_EMPTY = util.symbol(
- 'ATTR_EMPTY',
- """Symbol used internally to indicate an attribute had no callable."""
+ "ATTR_EMPTY",
+ """Symbol used internally to indicate an attribute had no callable.""",
)
NO_VALUE = util.symbol(
- 'NO_VALUE',
+ "NO_VALUE",
"""Symbol which may be placed as the 'previous' value of an attribute,
indicating no value was loaded for an attribute when it was modified,
and flags indicated we were not to load it.
- """
+ """,
)
NEVER_SET = util.symbol(
- 'NEVER_SET',
+ "NEVER_SET",
"""Symbol which may be placed as the 'previous' value of an attribute
indicating that the attribute had not been assigned to previously.
- """
+ """,
)
NO_CHANGE = util.symbol(
"NO_CHANGE",
"""No callables or SQL should be emitted on attribute access
and no state should change
- """, canonical=0
+ """,
+ canonical=0,
)
CALLABLES_OK = util.symbol(
"CALLABLES_OK",
"""Loader callables can be fired off if a value
is not present.
- """, canonical=1
+ """,
+ canonical=1,
)
SQL_OK = util.symbol(
"SQL_OK",
"""Loader callables can emit SQL at least on scalar value attributes.""",
- canonical=2
+ canonical=2,
)
RELATED_OBJECT_OK = util.symbol(
"RELATED_OBJECT_OK",
"""Callables can use SQL to load related objects as well
as scalar value attributes.
- """, canonical=4
+ """,
+ canonical=4,
)
INIT_OK = util.symbol(
@@ -82,111 +85,116 @@ INIT_OK = util.symbol(
"""Attributes should be initialized with a blank
value (None or an empty collection) upon get, if no other
value can be obtained.
- """, canonical=8
+ """,
+ canonical=8,
)
NON_PERSISTENT_OK = util.symbol(
"NON_PERSISTENT_OK",
"""Callables can be emitted if the parent is not persistent.""",
- canonical=16
+ canonical=16,
)
LOAD_AGAINST_COMMITTED = util.symbol(
"LOAD_AGAINST_COMMITTED",
"""Callables should use committed values as primary/foreign keys during a
load.
- """, canonical=32
+ """,
+ canonical=32,
)
NO_AUTOFLUSH = util.symbol(
"NO_AUTOFLUSH",
"""Loader callables should disable autoflush.""",
- canonical=64
+ canonical=64,
)
NO_RAISE = util.symbol(
"NO_RAISE",
"""Loader callables should not raise any assertions""",
- canonical=128
+ canonical=128,
)
# pre-packaged sets of flags used as inputs
PASSIVE_OFF = util.symbol(
"PASSIVE_OFF",
"Callables can be emitted in all cases.",
- canonical=(RELATED_OBJECT_OK | NON_PERSISTENT_OK |
- INIT_OK | CALLABLES_OK | SQL_OK)
+ canonical=(
+ RELATED_OBJECT_OK | NON_PERSISTENT_OK | INIT_OK | CALLABLES_OK | SQL_OK
+ ),
)
PASSIVE_RETURN_NEVER_SET = util.symbol(
"PASSIVE_RETURN_NEVER_SET",
"""PASSIVE_OFF ^ INIT_OK""",
- canonical=PASSIVE_OFF ^ INIT_OK
+ canonical=PASSIVE_OFF ^ INIT_OK,
)
PASSIVE_NO_INITIALIZE = util.symbol(
"PASSIVE_NO_INITIALIZE",
"PASSIVE_RETURN_NEVER_SET ^ CALLABLES_OK",
- canonical=PASSIVE_RETURN_NEVER_SET ^ CALLABLES_OK
+ canonical=PASSIVE_RETURN_NEVER_SET ^ CALLABLES_OK,
)
PASSIVE_NO_FETCH = util.symbol(
- "PASSIVE_NO_FETCH",
- "PASSIVE_OFF ^ SQL_OK",
- canonical=PASSIVE_OFF ^ SQL_OK
+ "PASSIVE_NO_FETCH", "PASSIVE_OFF ^ SQL_OK", canonical=PASSIVE_OFF ^ SQL_OK
)
PASSIVE_NO_FETCH_RELATED = util.symbol(
"PASSIVE_NO_FETCH_RELATED",
"PASSIVE_OFF ^ RELATED_OBJECT_OK",
- canonical=PASSIVE_OFF ^ RELATED_OBJECT_OK
+ canonical=PASSIVE_OFF ^ RELATED_OBJECT_OK,
)
PASSIVE_ONLY_PERSISTENT = util.symbol(
"PASSIVE_ONLY_PERSISTENT",
"PASSIVE_OFF ^ NON_PERSISTENT_OK",
- canonical=PASSIVE_OFF ^ NON_PERSISTENT_OK
+ canonical=PASSIVE_OFF ^ NON_PERSISTENT_OK,
)
-DEFAULT_MANAGER_ATTR = '_sa_class_manager'
-DEFAULT_STATE_ATTR = '_sa_instance_state'
-_INSTRUMENTOR = ('mapper', 'instrumentor')
+DEFAULT_MANAGER_ATTR = "_sa_class_manager"
+DEFAULT_STATE_ATTR = "_sa_instance_state"
+_INSTRUMENTOR = ("mapper", "instrumentor")
-EXT_CONTINUE = util.symbol('EXT_CONTINUE')
-EXT_STOP = util.symbol('EXT_STOP')
-EXT_SKIP = util.symbol('EXT_SKIP')
+EXT_CONTINUE = util.symbol("EXT_CONTINUE")
+EXT_STOP = util.symbol("EXT_STOP")
+EXT_SKIP = util.symbol("EXT_SKIP")
ONETOMANY = util.symbol(
- 'ONETOMANY',
+ "ONETOMANY",
"""Indicates the one-to-many direction for a :func:`.relationship`.
This symbol is typically used by the internals but may be exposed within
certain API features.
- """)
+ """,
+)
MANYTOONE = util.symbol(
- 'MANYTOONE',
+ "MANYTOONE",
"""Indicates the many-to-one direction for a :func:`.relationship`.
This symbol is typically used by the internals but may be exposed within
certain API features.
- """)
+ """,
+)
MANYTOMANY = util.symbol(
- 'MANYTOMANY',
+ "MANYTOMANY",
"""Indicates the many-to-many direction for a :func:`.relationship`.
This symbol is typically used by the internals but may be exposed within
certain API features.
- """)
+ """,
+)
NOT_EXTENSION = util.symbol(
- 'NOT_EXTENSION',
+ "NOT_EXTENSION",
"""Symbol indicating an :class:`InspectionAttr` that's
not part of sqlalchemy.ext.
Is assigned to the :attr:`.InspectionAttr.extension_type`
attribute.
- """)
+ """,
+)
_never_set = frozenset([NEVER_SET])
@@ -207,6 +215,7 @@ def _generative(*assertions):
assertion(self, fn.__name__)
fn(self, *args[1:], **kw)
return self
+
return generate
@@ -215,9 +224,10 @@ def _generative(*assertions):
def manager_of_class(cls):
return cls.__dict__.get(DEFAULT_MANAGER_ATTR, None)
+
instance_state = operator.attrgetter(DEFAULT_STATE_ATTR)
-instance_dict = operator.attrgetter('__dict__')
+instance_dict = operator.attrgetter("__dict__")
def instance_str(instance):
@@ -232,7 +242,7 @@ def state_str(state):
if state is None:
return "None"
else:
- return '<%s at 0x%x>' % (state.class_.__name__, id(state.obj()))
+ return "<%s at 0x%x>" % (state.class_.__name__, id(state.obj()))
def state_class_str(state):
@@ -243,7 +253,7 @@ def state_class_str(state):
if state is None:
return "None"
else:
- return '<%s>' % (state.class_.__name__, )
+ return "<%s>" % (state.class_.__name__,)
def attribute_str(instance, attribute):
@@ -335,15 +345,15 @@ def _is_mapped_class(entity):
"""
insp = inspection.inspect(entity, False)
- return insp is not None and \
- not insp.is_clause_element and \
- (
- insp.is_mapper or insp.is_aliased_class
- )
+ return (
+ insp is not None
+ and not insp.is_clause_element
+ and (insp.is_mapper or insp.is_aliased_class)
+ )
def _attr_as_key(attr):
- if hasattr(attr, 'key'):
+ if hasattr(attr, "key"):
return attr.key
else:
return expression._column_as_key(attr)
@@ -351,7 +361,7 @@ def _attr_as_key(attr):
def _orm_columns(entity):
insp = inspection.inspect(entity, False)
- if hasattr(insp, 'selectable') and hasattr(insp.selectable, 'c'):
+ if hasattr(insp, "selectable") and hasattr(insp.selectable, "c"):
return [c for c in insp.selectable.c]
else:
return [entity]
@@ -359,8 +369,7 @@ def _orm_columns(entity):
def _is_aliased_class(entity):
insp = inspection.inspect(entity, False)
- return insp is not None and \
- getattr(insp, "is_aliased_class", False)
+ return insp is not None and getattr(insp, "is_aliased_class", False)
def _entity_descriptor(entity, key):
@@ -386,11 +395,11 @@ def _entity_descriptor(entity, key):
return getattr(entity, key)
except AttributeError:
raise sa_exc.InvalidRequestError(
- "Entity '%s' has no property '%s'" %
- (description, key)
+ "Entity '%s' has no property '%s'" % (description, key)
)
-_state_mapper = util.dottedgetter('manager.mapper')
+
+_state_mapper = util.dottedgetter("manager.mapper")
@inspection._inspects(type)
@@ -429,7 +438,8 @@ def class_mapper(class_, configure=True):
if mapper is None:
if not isinstance(class_, type):
raise sa_exc.ArgumentError(
- "Class object expected, got '%r'." % (class_, ))
+ "Class object expected, got '%r'." % (class_,)
+ )
raise exc.UnmappedClassError(class_)
else:
return mapper
@@ -449,6 +459,7 @@ class InspectionAttr(object):
here intact for forwards-compatibility.
"""
+
__slots__ = ()
is_selectable = False
@@ -551,4 +562,5 @@ class _MappedAttribute(object):
attributes.
"""
+
__slots__ = ()
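
The base.py portion mostly reparenthesizes the util.symbol() calls; the canonical bit values themselves are untouched. A rough sketch, using plain ints in place of util.symbol() (values copied from the hunks above), shows how the pre-packaged PASSIVE_* sets compose:

    # Plain-int stand-ins for util.symbol(..., canonical=N) so the bit
    # arithmetic can be checked directly; values are from the diff above.
    NO_CHANGE = 0
    CALLABLES_OK = 1
    SQL_OK = 2
    RELATED_OBJECT_OK = 4
    INIT_OK = 8
    NON_PERSISTENT_OK = 16
    LOAD_AGAINST_COMMITTED = 32
    NO_AUTOFLUSH = 64
    NO_RAISE = 128

    PASSIVE_OFF = (
        RELATED_OBJECT_OK | NON_PERSISTENT_OK | INIT_OK | CALLABLES_OK | SQL_OK
    )
    # XOR acts as "remove this flag" here only because each flag is known
    # to be set in PASSIVE_OFF.
    PASSIVE_RETURN_NEVER_SET = PASSIVE_OFF ^ INIT_OK
    PASSIVE_NO_INITIALIZE = PASSIVE_RETURN_NEVER_SET ^ CALLABLES_OK
    PASSIVE_NO_FETCH = PASSIVE_OFF ^ SQL_OK

    assert PASSIVE_OFF == 31
    assert not PASSIVE_NO_INITIALIZE & CALLABLES_OK  # callables suppressed
    assert not PASSIVE_NO_FETCH & SQL_OK             # SQL suppressed

Note that ^ binds tighter than | in Python, which is what the reformatted ScalarObjectAttributeImpl expression PASSIVE_NO_FETCH ^ INIT_OK | LOAD_AGAINST_COMMITTED | NO_RAISE relies on; the new line-broken layout makes that grouping easier to see but does not change it.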
diff --git a/lib/sqlalchemy/orm/collections.py b/lib/sqlalchemy/orm/collections.py
index 54c29bb5e..be9291741 100644
--- a/lib/sqlalchemy/orm/collections.py
+++ b/lib/sqlalchemy/orm/collections.py
@@ -113,9 +113,13 @@ from . import base
from sqlalchemy.util.compat import inspect_getargspec
-__all__ = ['collection', 'collection_adapter',
- 'mapped_collection', 'column_mapped_collection',
- 'attribute_mapped_collection']
+__all__ = [
+ "collection",
+ "collection_adapter",
+ "mapped_collection",
+ "column_mapped_collection",
+ "attribute_mapped_collection",
+]
__instrumentation_mutex = util.threading.Lock()
@@ -172,10 +176,12 @@ class _SerializableColumnGetter(object):
def __call__(self, value):
state = base.instance_state(value)
m = base._state_mapper(state)
- key = [m._get_state_attr_by_column(
- state, state.dict,
- m.mapped_table.columns[k])
- for k in self.colkeys]
+ key = [
+ m._get_state_attr_by_column(
+ state, state.dict, m.mapped_table.columns[k]
+ )
+ for k in self.colkeys
+ ]
if self.composite:
return tuple(key)
else:
@@ -208,16 +214,15 @@ class _SerializableColumnGetterV2(_PlainColumnGetter):
return None
else:
return c.table.key
+
colkeys = [(c.key, _table_key(c)) for c in cols]
return _SerializableColumnGetterV2, (colkeys,)
def _cols(self, mapper):
cols = []
- metadata = getattr(mapper.local_table, 'metadata', None)
+ metadata = getattr(mapper.local_table, "metadata", None)
for (ckey, tkey) in self.colkeys:
- if tkey is None or \
- metadata is None or \
- tkey not in metadata:
+ if tkey is None or metadata is None or tkey not in metadata:
cols.append(mapper.local_table.c[ckey])
else:
cols.append(metadata.tables[tkey].c[ckey])
@@ -237,9 +242,10 @@ def column_mapped_collection(mapping_spec):
after a session flush.
"""
- cols = [expression._only_column_elements(q, "mapping_spec")
- for q in util.to_list(mapping_spec)
- ]
+ cols = [
+ expression._only_column_elements(q, "mapping_spec")
+ for q in util.to_list(mapping_spec)
+ ]
keyfunc = _PlainColumnGetter(cols)
return lambda: MappedCollection(keyfunc)
@@ -253,7 +259,7 @@ class _SerializableAttrGetter(object):
return self.getter(target)
def __reduce__(self):
- return _SerializableAttrGetter, (self.name, )
+ return _SerializableAttrGetter, (self.name,)
def attribute_mapped_collection(attr_name):
@@ -311,6 +317,7 @@ class collection(object):
def popitem(self): ...
"""
+
# Bundled as a class solely for ease of use: packaging, doc strings,
# importability.
@@ -355,7 +362,7 @@ class collection(object):
promulgation to collection events.
"""
- fn._sa_instrument_role = 'appender'
+ fn._sa_instrument_role = "appender"
return fn
@staticmethod
@@ -382,7 +389,7 @@ class collection(object):
promulgation to collection events.
"""
- fn._sa_instrument_role = 'remover'
+ fn._sa_instrument_role = "remover"
return fn
@staticmethod
@@ -396,7 +403,7 @@ class collection(object):
def __iter__(self): ...
"""
- fn._sa_instrument_role = 'iterator'
+ fn._sa_instrument_role = "iterator"
return fn
@staticmethod
@@ -435,7 +442,7 @@ class collection(object):
and :meth:`.AttributeEvents.dispose_collection` handlers.
"""
- fn._sa_instrument_role = 'linker'
+ fn._sa_instrument_role = "linker"
return fn
link = linker
@@ -472,7 +479,7 @@ class collection(object):
validation on the values about to be assigned.
"""
- fn._sa_instrument_role = 'converter'
+ fn._sa_instrument_role = "converter"
return fn
@staticmethod
@@ -491,9 +498,11 @@ class collection(object):
def do_stuff(self, thing, entity=None): ...
"""
+
def decorator(fn):
- fn._sa_instrument_before = ('fire_append_event', arg)
+ fn._sa_instrument_before = ("fire_append_event", arg)
return fn
+
return decorator
@staticmethod
@@ -511,10 +520,12 @@ class collection(object):
def __setitem__(self, index, item): ...
"""
+
def decorator(fn):
- fn._sa_instrument_before = ('fire_append_event', arg)
- fn._sa_instrument_after = 'fire_remove_event'
+ fn._sa_instrument_before = ("fire_append_event", arg)
+ fn._sa_instrument_after = "fire_remove_event"
return fn
+
return decorator
@staticmethod
@@ -533,9 +544,11 @@ class collection(object):
collection.removes_return.
"""
+
def decorator(fn):
- fn._sa_instrument_before = ('fire_remove_event', arg)
+ fn._sa_instrument_before = ("fire_remove_event", arg)
return fn
+
return decorator
@staticmethod
@@ -553,13 +566,15 @@ class collection(object):
collection.remove.
"""
+
def decorator(fn):
- fn._sa_instrument_after = 'fire_remove_event'
+ fn._sa_instrument_after = "fire_remove_event"
return fn
+
return decorator
-collection_adapter = operator.attrgetter('_sa_adapter')
+collection_adapter = operator.attrgetter("_sa_adapter")
"""Fetch the :class:`.CollectionAdapter` for a collection."""
@@ -577,7 +592,13 @@ class CollectionAdapter(object):
"""
__slots__ = (
- 'attr', '_key', '_data', 'owner_state', '_converter', 'invalidated')
+ "attr",
+ "_key",
+ "_data",
+ "owner_state",
+ "_converter",
+ "invalidated",
+ )
def __init__(self, attr, owner_state, data):
self.attr = attr
@@ -676,9 +697,8 @@ class CollectionAdapter(object):
if self.invalidated:
self._warn_invalidated()
return self.attr.fire_append_event(
- self.owner_state,
- self.owner_state.dict,
- item, initiator)
+ self.owner_state, self.owner_state.dict, item, initiator
+ )
else:
return item
@@ -694,9 +714,8 @@ class CollectionAdapter(object):
if self.invalidated:
self._warn_invalidated()
self.attr.fire_remove_event(
- self.owner_state,
- self.owner_state.dict,
- item, initiator)
+ self.owner_state, self.owner_state.dict, item, initiator
+ )
def fire_pre_remove_event(self, initiator=None):
"""Notify that an entity is about to be removed from the collection.
@@ -708,25 +727,26 @@ class CollectionAdapter(object):
if self.invalidated:
self._warn_invalidated()
self.attr.fire_pre_remove_event(
- self.owner_state,
- self.owner_state.dict,
- initiator=initiator)
+ self.owner_state, self.owner_state.dict, initiator=initiator
+ )
def __getstate__(self):
- return {'key': self._key,
- 'owner_state': self.owner_state,
- 'owner_cls': self.owner_state.class_,
- 'data': self.data,
- 'invalidated': self.invalidated}
+ return {
+ "key": self._key,
+ "owner_state": self.owner_state,
+ "owner_cls": self.owner_state.class_,
+ "data": self.data,
+ "invalidated": self.invalidated,
+ }
def __setstate__(self, d):
- self._key = d['key']
- self.owner_state = d['owner_state']
- self._data = weakref.ref(d['data'])
- self._converter = d['data']._sa_converter
- d['data']._sa_adapter = self
- self.invalidated = d['invalidated']
- self.attr = getattr(d['owner_cls'], self._key).impl
+ self._key = d["key"]
+ self.owner_state = d["owner_state"]
+ self._data = weakref.ref(d["data"])
+ self._converter = d["data"]._sa_converter
+ d["data"]._sa_adapter = self
+ self.invalidated = d["invalidated"]
+ self.attr = getattr(d["owner_cls"], self._key).impl
def bulk_replace(values, existing_adapter, new_adapter, initiator=None):
@@ -796,7 +816,7 @@ def prepare_instrumentation(factory):
# Instrument the class if needed.
if __instrumentation_mutex.acquire():
try:
- if getattr(cls, '_sa_instrumented', None) != id(cls):
+ if getattr(cls, "_sa_instrumented", None) != id(cls):
_instrument_class(cls)
finally:
__instrumentation_mutex.release()
@@ -829,10 +849,11 @@ def _instrument_class(cls):
# In the normal call flow, a request for any of the 3 basic collection
# types is transformed into one of our trivial subclasses
# (e.g. InstrumentedList). Catch anything else that sneaks in here...
- if cls.__module__ == '__builtin__':
+ if cls.__module__ == "__builtin__":
raise sa_exc.ArgumentError(
"Can not instrument a built-in type. Use a "
- "subclass, even a trivial one.")
+ "subclass, even a trivial one."
+ )
roles, methods = _locate_roles_and_methods(cls)
@@ -858,25 +879,30 @@ def _locate_roles_and_methods(cls):
continue
# note role declarations
- if hasattr(method, '_sa_instrument_role'):
+ if hasattr(method, "_sa_instrument_role"):
role = method._sa_instrument_role
- assert role in ('appender', 'remover', 'iterator',
- 'linker', 'converter')
+ assert role in (
+ "appender",
+ "remover",
+ "iterator",
+ "linker",
+ "converter",
+ )
roles.setdefault(role, name)
# transfer instrumentation requests from decorated function
# to the combined queue
before, after = None, None
- if hasattr(method, '_sa_instrument_before'):
+ if hasattr(method, "_sa_instrument_before"):
op, argument = method._sa_instrument_before
- assert op in ('fire_append_event', 'fire_remove_event')
+ assert op in ("fire_append_event", "fire_remove_event")
before = op, argument
- if hasattr(method, '_sa_instrument_after'):
+ if hasattr(method, "_sa_instrument_after"):
op = method._sa_instrument_after
- assert op in ('fire_append_event', 'fire_remove_event')
+ assert op in ("fire_append_event", "fire_remove_event")
after = op
if before:
- methods[name] = before + (after, )
+ methods[name] = before + (after,)
elif after:
methods[name] = None, None, after
return roles, methods
@@ -898,8 +924,11 @@ def _setup_canned_roles(cls, roles, methods):
# apply ABC auto-decoration to methods that need it
for method, decorator in decorators.items():
fn = getattr(cls, method, None)
- if (fn and method not in methods and
- not hasattr(fn, '_sa_instrumented')):
+ if (
+ fn
+ and method not in methods
+ and not hasattr(fn, "_sa_instrumented")
+ ):
setattr(cls, method, decorator(fn))
@@ -908,26 +937,31 @@ def _assert_required_roles(cls, roles, methods):
needed
"""
- if 'appender' not in roles or not hasattr(cls, roles['appender']):
+ if "appender" not in roles or not hasattr(cls, roles["appender"]):
raise sa_exc.ArgumentError(
"Type %s must elect an appender method to be "
- "a collection class" % cls.__name__)
- elif (roles['appender'] not in methods and
- not hasattr(getattr(cls, roles['appender']), '_sa_instrumented')):
- methods[roles['appender']] = ('fire_append_event', 1, None)
-
- if 'remover' not in roles or not hasattr(cls, roles['remover']):
+ "a collection class" % cls.__name__
+ )
+ elif roles["appender"] not in methods and not hasattr(
+ getattr(cls, roles["appender"]), "_sa_instrumented"
+ ):
+ methods[roles["appender"]] = ("fire_append_event", 1, None)
+
+ if "remover" not in roles or not hasattr(cls, roles["remover"]):
raise sa_exc.ArgumentError(
"Type %s must elect a remover method to be "
- "a collection class" % cls.__name__)
- elif (roles['remover'] not in methods and
- not hasattr(getattr(cls, roles['remover']), '_sa_instrumented')):
- methods[roles['remover']] = ('fire_remove_event', 1, None)
-
- if 'iterator' not in roles or not hasattr(cls, roles['iterator']):
+ "a collection class" % cls.__name__
+ )
+ elif roles["remover"] not in methods and not hasattr(
+ getattr(cls, roles["remover"]), "_sa_instrumented"
+ ):
+ methods[roles["remover"]] = ("fire_remove_event", 1, None)
+
+ if "iterator" not in roles or not hasattr(cls, roles["iterator"]):
raise sa_exc.ArgumentError(
"Type %s must elect an iterator method to be "
- "a collection class" % cls.__name__)
+ "a collection class" % cls.__name__
+ )
def _set_collection_attributes(cls, roles, methods):
@@ -936,16 +970,20 @@ def _set_collection_attributes(cls, roles, methods):
"""
for method_name, (before, argument, after) in methods.items():
- setattr(cls, method_name,
- _instrument_membership_mutator(getattr(cls, method_name),
- before, argument, after))
+ setattr(
+ cls,
+ method_name,
+ _instrument_membership_mutator(
+ getattr(cls, method_name), before, argument, after
+ ),
+ )
# intern the role map
for role, method_name in roles.items():
- setattr(cls, '_sa_%s' % role, getattr(cls, method_name))
+ setattr(cls, "_sa_%s" % role, getattr(cls, method_name))
cls._sa_adapter = None
- if not hasattr(cls, '_sa_converter'):
+ if not hasattr(cls, "_sa_converter"):
cls._sa_converter = None
cls._sa_instrumented = id(cls)
@@ -972,7 +1010,8 @@ def _instrument_membership_mutator(method, before, argument, after):
if pos_arg is None:
if named_arg not in kw:
raise sa_exc.ArgumentError(
- "Missing argument %s" % argument)
+ "Missing argument %s" % argument
+ )
value = kw[named_arg]
else:
if len(args) > pos_arg:
@@ -981,9 +1020,10 @@ def _instrument_membership_mutator(method, before, argument, after):
value = kw[named_arg]
else:
raise sa_exc.ArgumentError(
- "Missing argument %s" % argument)
+ "Missing argument %s" % argument
+ )
- initiator = kw.pop('_sa_initiator', None)
+ initiator = kw.pop("_sa_initiator", None)
if initiator is False:
executor = None
else:
@@ -1055,6 +1095,7 @@ def _list_decorators():
def append(self, item, _sa_initiator=None):
item = __set(self, item, _sa_initiator)
fn(self, item)
+
_tidy(append)
return append
@@ -1063,6 +1104,7 @@ def _list_decorators():
__del(self, value, _sa_initiator)
# testlib.pragma exempt:__eq__
fn(self, value)
+
_tidy(remove)
return remove
@@ -1070,6 +1112,7 @@ def _list_decorators():
def insert(self, index, value):
value = __set(self, value)
fn(self, index, value)
+
_tidy(insert)
return insert
@@ -1106,10 +1149,12 @@ def _list_decorators():
if len(value) != len(rng):
raise ValueError(
"attempt to assign sequence of size %s to "
- "extended slice of size %s" % (len(value),
- len(rng)))
+ "extended slice of size %s"
+ % (len(value), len(rng))
+ )
for i, item in zip(rng, value):
self.__setitem__(i, item)
+
_tidy(__setitem__)
return __setitem__
@@ -1126,16 +1171,19 @@ def _list_decorators():
for item in self[index]:
__del(self, item)
fn(self, index)
+
_tidy(__delitem__)
return __delitem__
if util.py2k:
+
def __setslice__(fn):
def __setslice__(self, start, end, values):
for value in self[start:end]:
__del(self, value)
values = [__set(self, value) for value in values]
fn(self, start, end, values)
+
_tidy(__setslice__)
return __setslice__
@@ -1144,6 +1192,7 @@ def _list_decorators():
for value in self[start:end]:
__del(self, value)
fn(self, start, end)
+
_tidy(__delslice__)
return __delslice__
@@ -1151,6 +1200,7 @@ def _list_decorators():
def extend(self, iterable):
for value in iterable:
self.append(value)
+
_tidy(extend)
return extend
@@ -1161,6 +1211,7 @@ def _list_decorators():
for value in iterable:
self.append(value)
return self
+
_tidy(__iadd__)
return __iadd__
@@ -1170,15 +1221,18 @@ def _list_decorators():
item = fn(self, index)
__del(self, item)
return item
+
_tidy(pop)
return pop
if not util.py2k:
+
def clear(fn):
def clear(self, index=-1):
for item in self:
__del(self, item)
fn(self)
+
_tidy(clear)
return clear
@@ -1188,7 +1242,7 @@ def _list_decorators():
# desired. hard to imagine a use case for __imul__, though.
l = locals().copy()
- l.pop('_tidy')
+ l.pop("_tidy")
return l
@@ -1199,7 +1253,7 @@ def _dict_decorators():
fn._sa_instrumented = True
fn.__doc__ = getattr(dict, fn.__name__).__doc__
- Unspecified = util.symbol('Unspecified')
+ Unspecified = util.symbol("Unspecified")
def __setitem__(fn):
def __setitem__(self, key, value, _sa_initiator=None):
@@ -1207,6 +1261,7 @@ def _dict_decorators():
__del(self, self[key], _sa_initiator)
value = __set(self, value, _sa_initiator)
fn(self, key, value)
+
_tidy(__setitem__)
return __setitem__
@@ -1215,6 +1270,7 @@ def _dict_decorators():
if key in self:
__del(self, self[key], _sa_initiator)
fn(self, key)
+
_tidy(__delitem__)
return __delitem__
@@ -1223,6 +1279,7 @@ def _dict_decorators():
for key in self:
__del(self, self[key])
fn(self)
+
_tidy(clear)
return clear
@@ -1237,6 +1294,7 @@ def _dict_decorators():
if _to_del:
__del(self, item)
return item
+
_tidy(pop)
return pop
@@ -1246,6 +1304,7 @@ def _dict_decorators():
item = fn(self)
__del(self, item[1])
return item
+
_tidy(popitem)
return popitem
@@ -1256,16 +1315,16 @@ def _dict_decorators():
return default
else:
return self.__getitem__(key)
+
_tidy(setdefault)
return setdefault
def update(fn):
def update(self, __other=Unspecified, **kw):
if __other is not Unspecified:
- if hasattr(__other, 'keys'):
+ if hasattr(__other, "keys"):
for key in list(__other):
- if (key not in self or
- self[key] is not __other[key]):
+ if key not in self or self[key] is not __other[key]:
self[key] = __other[key]
else:
for key, value in __other:
@@ -1274,14 +1333,16 @@ def _dict_decorators():
for key in kw:
if key not in self or self[key] is not kw[key]:
self[key] = kw[key]
+
_tidy(update)
return update
l = locals().copy()
- l.pop('_tidy')
- l.pop('Unspecified')
+ l.pop("_tidy")
+ l.pop("Unspecified")
return l
+
_set_binop_bases = (set, frozenset)
@@ -1293,8 +1354,10 @@ def _set_binops_check_strict(self, obj):
def _set_binops_check_loose(self, obj):
"""Allow anything set-like to participate in set binops."""
- return (isinstance(obj, _set_binop_bases + (self.__class__,)) or
- util.duck_type_collection(obj) == set)
+ return (
+ isinstance(obj, _set_binop_bases + (self.__class__,))
+ or util.duck_type_collection(obj) == set
+ )
def _set_decorators():
@@ -1304,7 +1367,7 @@ def _set_decorators():
fn._sa_instrumented = True
fn.__doc__ = getattr(set, fn.__name__).__doc__
- Unspecified = util.symbol('Unspecified')
+ Unspecified = util.symbol("Unspecified")
def add(fn):
def add(self, value, _sa_initiator=None):
@@ -1312,6 +1375,7 @@ def _set_decorators():
value = __set(self, value, _sa_initiator)
# testlib.pragma exempt:__hash__
fn(self, value)
+
_tidy(add)
return add
@@ -1322,6 +1386,7 @@ def _set_decorators():
__del(self, value, _sa_initiator)
# testlib.pragma exempt:__hash__
fn(self, value)
+
_tidy(discard)
return discard
@@ -1332,6 +1397,7 @@ def _set_decorators():
__del(self, value, _sa_initiator)
# testlib.pragma exempt:__hash__
fn(self, value)
+
_tidy(remove)
return remove
@@ -1343,6 +1409,7 @@ def _set_decorators():
# that will be popped before pop is called.
__del(self, item)
return item
+
_tidy(pop)
return pop
@@ -1350,6 +1417,7 @@ def _set_decorators():
def clear(self):
for item in list(self):
self.remove(item)
+
_tidy(clear)
return clear
@@ -1357,6 +1425,7 @@ def _set_decorators():
def update(self, value):
for item in value:
self.add(item)
+
_tidy(update)
return update
@@ -1367,6 +1436,7 @@ def _set_decorators():
for item in value:
self.add(item)
return self
+
_tidy(__ior__)
return __ior__
@@ -1374,6 +1444,7 @@ def _set_decorators():
def difference_update(self, value):
for item in value:
self.discard(item)
+
_tidy(difference_update)
return difference_update
@@ -1384,6 +1455,7 @@ def _set_decorators():
for item in value:
self.discard(item)
return self
+
_tidy(__isub__)
return __isub__
@@ -1396,6 +1468,7 @@ def _set_decorators():
self.remove(item)
for item in add:
self.add(item)
+
_tidy(intersection_update)
return intersection_update
@@ -1411,6 +1484,7 @@ def _set_decorators():
for item in add:
self.add(item)
return self
+
_tidy(__iand__)
return __iand__
@@ -1423,6 +1497,7 @@ def _set_decorators():
self.remove(item)
for item in add:
self.add(item)
+
_tidy(symmetric_difference_update)
return symmetric_difference_update
@@ -1438,12 +1513,13 @@ def _set_decorators():
for item in add:
self.add(item)
return self
+
_tidy(__ixor__)
return __ixor__
l = locals().copy()
- l.pop('_tidy')
- l.pop('Unspecified')
+ l.pop("_tidy")
+ l.pop("Unspecified")
return l
@@ -1467,18 +1543,17 @@ __canned_instrumentation = {
__interfaces = {
list: (
- {'appender': 'append', 'remover': 'remove',
- 'iterator': '__iter__'}, _list_decorators()
+ {"appender": "append", "remover": "remove", "iterator": "__iter__"},
+ _list_decorators(),
+ ),
+ set: (
+ {"appender": "add", "remover": "remove", "iterator": "__iter__"},
+ _set_decorators(),
),
-
- set: ({'appender': 'add',
- 'remover': 'remove',
- 'iterator': '__iter__'}, _set_decorators()
- ),
-
# decorators are required for dicts and object collections.
- dict: ({'iterator': 'values'}, _dict_decorators()) if util.py3k
- else ({'iterator': 'itervalues'}, _dict_decorators()),
+ dict: ({"iterator": "values"}, _dict_decorators())
+ if util.py3k
+ else ({"iterator": "itervalues"}, _dict_decorators()),
}
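
The mapping above names the instrumentation "roles" for each built-in: for lists, append/remove/__iter__ serve as appender, remover, and iterator. A user-defined subclass picks these up automatically when passed as collection_class; a minimal sketch (the relationship target name is hypothetical):

from sqlalchemy.orm import relationship

class OrderedChildren(list):
    # a plain list subclass: its inherited append/remove/__iter__
    # satisfy the appender/remover/iterator roles listed above
    pass

# children = relationship("Child", collection_class=OrderedChildren)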
@@ -1529,10 +1604,11 @@ class MappedCollection(dict):
"Can not remove '%s': collection holds '%s' for key '%s'. "
"Possible cause: is the MappedCollection key function "
"based on mutable properties or properties that only obtain "
- "values after flush?" %
- (value, self[key], key))
+ "values after flush?" % (value, self[key], key)
+ )
self.__delitem__(key, _sa_initiator)
+
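
The FlushError above fires when the key a value maps to can no longer be recomputed at removal time, which is why MappedCollection key functions should be based on immutable, already-populated attributes. The stock helper shows the intended use (the "keyword" attribute name is hypothetical):

from sqlalchemy.orm.collections import attribute_mapped_collection

# keys the dict on Note.keyword; mutating note.keyword while the note
# is in the collection is exactly the "mutable properties" case the
# error message warns about
keyed_by_keyword = attribute_mapped_collection("keyword")

# notes = relationship("Note", collection_class=keyed_by_keyword)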
# ensure instrumentation is associated with
# these built-in classes; if a user-defined class
# subclasses these and uses @internally_instrumented,
diff --git a/lib/sqlalchemy/orm/dependency.py b/lib/sqlalchemy/orm/dependency.py
index 960b9e5d5..cba4d2141 100644
--- a/lib/sqlalchemy/orm/dependency.py
+++ b/lib/sqlalchemy/orm/dependency.py
@@ -10,8 +10,7 @@
"""
from .. import sql, util, exc as sa_exc
-from . import attributes, exc, sync, unitofwork, \
- util as mapperutil
+from . import attributes, exc, sync, unitofwork, util as mapperutil
from .interfaces import ONETOMANY, MANYTOONE, MANYTOMANY
@@ -41,8 +40,8 @@ class DependencyProcessor(object):
raise sa_exc.ArgumentError(
"Can't build a DependencyProcessor for relationship %s. "
"No target attributes to populate between parent and "
- "child are present" %
- self.prop)
+ "child are present" % self.prop
+ )
@classmethod
def from_relationship(cls, prop):
@@ -70,31 +69,28 @@ class DependencyProcessor(object):
before_delete = unitofwork.ProcessAll(uow, self, True, True)
parent_saves = unitofwork.SaveUpdateAll(
- uow,
- self.parent.primary_base_mapper
+ uow, self.parent.primary_base_mapper
)
child_saves = unitofwork.SaveUpdateAll(
- uow,
- self.mapper.primary_base_mapper
+ uow, self.mapper.primary_base_mapper
)
parent_deletes = unitofwork.DeleteAll(
- uow,
- self.parent.primary_base_mapper
+ uow, self.parent.primary_base_mapper
)
child_deletes = unitofwork.DeleteAll(
- uow,
- self.mapper.primary_base_mapper
+ uow, self.mapper.primary_base_mapper
)
- self.per_property_dependencies(uow,
- parent_saves,
- child_saves,
- parent_deletes,
- child_deletes,
- after_save,
- before_delete
- )
+ self.per_property_dependencies(
+ uow,
+ parent_saves,
+ child_saves,
+ parent_deletes,
+ child_deletes,
+ after_save,
+ before_delete,
+ )
def per_state_flush_actions(self, uow, states, isdelete):
"""establish actions and dependencies related to a flush.
@@ -130,9 +126,7 @@ class DependencyProcessor(object):
# child side is not part of the cycle, so we will link per-state
# actions to the aggregate "saves", "deletes" actions
- child_actions = [
- (child_saves, False), (child_deletes, True)
- ]
+ child_actions = [(child_saves, False), (child_deletes, True)]
child_in_cycles = False
else:
child_in_cycles = True
@@ -140,15 +134,13 @@ class DependencyProcessor(object):
# check if the "parent" side is part of the cycle
if not isdelete:
parent_saves = unitofwork.SaveUpdateAll(
- uow,
- self.parent.base_mapper)
+ uow, self.parent.base_mapper
+ )
parent_deletes = before_delete = None
if parent_saves in uow.cycles:
parent_in_cycles = True
else:
- parent_deletes = unitofwork.DeleteAll(
- uow,
- self.parent.base_mapper)
+ parent_deletes = unitofwork.DeleteAll(uow, self.parent.base_mapper)
parent_saves = after_save = None
if parent_deletes in uow.cycles:
parent_in_cycles = True
@@ -160,17 +152,18 @@ class DependencyProcessor(object):
# by a preprocessor on this state/attribute. In the
# case of deletes we may try to load missing items here as well.
sum_ = state.manager[self.key].impl.get_all_pending(
- state, state.dict,
+ state,
+ state.dict,
self._passive_delete_flag
if isdelete
- else attributes.PASSIVE_NO_INITIALIZE)
+ else attributes.PASSIVE_NO_INITIALIZE,
+ )
if not sum_:
continue
if isdelete:
- before_delete = unitofwork.ProcessState(uow,
- self, True, state)
+ before_delete = unitofwork.ProcessState(uow, self, True, state)
if parent_in_cycles:
parent_deletes = unitofwork.DeleteState(uow, state)
else:
@@ -188,21 +181,28 @@ class DependencyProcessor(object):
if deleted:
child_action = (
unitofwork.DeleteState(uow, child_state),
- True)
+ True,
+ )
else:
child_action = (
unitofwork.SaveUpdateState(uow, child_state),
- False)
+ False,
+ )
child_actions.append(child_action)
# establish dependencies between our possibly per-state
# parent action and our possibly per-state child action.
for child_action, childisdelete in child_actions:
- self.per_state_dependencies(uow, parent_saves,
- parent_deletes,
- child_action,
- after_save, before_delete,
- isdelete, childisdelete)
+ self.per_state_dependencies(
+ uow,
+ parent_saves,
+ parent_deletes,
+ child_action,
+ after_save,
+ before_delete,
+ isdelete,
+ childisdelete,
+ )
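
Throughout this module, entries added to uow.dependencies are (before, after) tuples, and flush order is derived by topologically sorting them. A self-contained sketch of that convention (an illustration only, not the ORM's actual sorter):

from collections import defaultdict, deque

def topo_order(nodes, dependencies):
    # dependencies: iterable of (before, after) pairs over 'nodes'
    indegree = dict.fromkeys(nodes, 0)
    successors = defaultdict(list)
    for before, after in dependencies:
        successors[before].append(after)
        indegree[after] += 1
    ready = deque(n for n in nodes if indegree[n] == 0)
    order = []
    while ready:
        node = ready.popleft()
        order.append(node)
        for nxt in successors[node]:
            indegree[nxt] -= 1
            if indegree[nxt] == 0:
                ready.append(nxt)
    if len(order) != len(nodes):
        raise ValueError("dependency cycle")
    return order

# topo_order(
#     ["parent_saves", "after_save", "child_saves"],
#     [("parent_saves", "after_save"), ("after_save", "child_saves")],
# )  -> ['parent_saves', 'after_save', 'child_saves']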
def presort_deletes(self, uowcommit, states):
return False
@@ -228,76 +228,74 @@ class DependencyProcessor(object):
# TODO: add a high speed method
# to InstanceState which returns: attribute
# has a non-None value, or had one
- history = uowcommit.get_attribute_history(
- s,
- self.key,
- passive)
+ history = uowcommit.get_attribute_history(s, self.key, passive)
if history and not history.empty():
return True
else:
- return states and \
- not self.prop._is_self_referential and \
- self.mapper in uowcommit.mappers
+ return (
+ states
+ and not self.prop._is_self_referential
+ and self.mapper in uowcommit.mappers
+ )
def _verify_canload(self, state):
if self.prop.uselist and state is None:
raise exc.FlushError(
"Can't flush None value found in "
- "collection %s" % (self.prop, ))
- elif state is not None and \
- not self.mapper._canload(
- state, allow_subtypes=not self.enable_typechecks):
+ "collection %s" % (self.prop,)
+ )
+ elif state is not None and not self.mapper._canload(
+ state, allow_subtypes=not self.enable_typechecks
+ ):
if self.mapper._canload(state, allow_subtypes=True):
- raise exc.FlushError('Attempting to flush an item of type '
- '%(x)s as a member of collection '
- '"%(y)s". Expected an object of type '
- '%(z)s or a polymorphic subclass of '
- 'this type. If %(x)s is a subclass of '
- '%(z)s, configure mapper "%(zm)s" to '
- 'load this subtype polymorphically, or '
- 'set enable_typechecks=False to allow '
- 'any subtype to be accepted for flush. '
- % {
- 'x': state.class_,
- 'y': self.prop,
- 'z': self.mapper.class_,
- 'zm': self.mapper,
- })
+ raise exc.FlushError(
+ "Attempting to flush an item of type "
+ "%(x)s as a member of collection "
+ '"%(y)s". Expected an object of type '
+ "%(z)s or a polymorphic subclass of "
+ "this type. If %(x)s is a subclass of "
+ '%(z)s, configure mapper "%(zm)s" to '
+ "load this subtype polymorphically, or "
+ "set enable_typechecks=False to allow "
+ "any subtype to be accepted for flush. "
+ % {
+ "x": state.class_,
+ "y": self.prop,
+ "z": self.mapper.class_,
+ "zm": self.mapper,
+ }
+ )
else:
raise exc.FlushError(
- 'Attempting to flush an item of type '
- '%(x)s as a member of collection '
+ "Attempting to flush an item of type "
+ "%(x)s as a member of collection "
'"%(y)s". Expected an object of type '
- '%(z)s or a polymorphic subclass of '
- 'this type.' % {
- 'x': state.class_,
- 'y': self.prop,
- 'z': self.mapper.class_,
- })
-
- def _synchronize(self, state, child, associationrow,
- clearkeys, uowcommit):
+ "%(z)s or a polymorphic subclass of "
+ "this type."
+ % {
+ "x": state.class_,
+ "y": self.prop,
+ "z": self.mapper.class_,
+ }
+ )
+
+ def _synchronize(self, state, child, associationrow, clearkeys, uowcommit):
raise NotImplementedError()
def _get_reversed_processed_set(self, uow):
if not self.prop._reverse_property:
return None
- process_key = tuple(sorted(
- [self.key] +
- [p.key for p in self.prop._reverse_property]
- ))
- return uow.memo(
- ('reverse_key', process_key),
- set
+ process_key = tuple(
+ sorted([self.key] + [p.key for p in self.prop._reverse_property])
)
+ return uow.memo(("reverse_key", process_key), set)
def _post_update(self, state, uowcommit, related, is_m2o_delete=False):
for x in related:
if not is_m2o_delete or x is not None:
uowcommit.register_post_update(
- state,
- [r for l, r in self.prop.synchronize_pairs]
+ state, [r for l, r in self.prop.synchronize_pairs]
)
break
@@ -309,114 +307,126 @@ class DependencyProcessor(object):
class OneToManyDP(DependencyProcessor):
-
- def per_property_dependencies(self, uow, parent_saves,
- child_saves,
- parent_deletes,
- child_deletes,
- after_save,
- before_delete,
- ):
+ def per_property_dependencies(
+ self,
+ uow,
+ parent_saves,
+ child_saves,
+ parent_deletes,
+ child_deletes,
+ after_save,
+ before_delete,
+ ):
if self.post_update:
child_post_updates = unitofwork.PostUpdateAll(
- uow,
- self.mapper.primary_base_mapper,
- False)
+ uow, self.mapper.primary_base_mapper, False
+ )
child_pre_updates = unitofwork.PostUpdateAll(
- uow,
- self.mapper.primary_base_mapper,
- True)
-
- uow.dependencies.update([
- (child_saves, after_save),
- (parent_saves, after_save),
- (after_save, child_post_updates),
-
- (before_delete, child_pre_updates),
- (child_pre_updates, parent_deletes),
- (child_pre_updates, child_deletes),
-
- ])
+ uow, self.mapper.primary_base_mapper, True
+ )
+
+ uow.dependencies.update(
+ [
+ (child_saves, after_save),
+ (parent_saves, after_save),
+ (after_save, child_post_updates),
+ (before_delete, child_pre_updates),
+ (child_pre_updates, parent_deletes),
+ (child_pre_updates, child_deletes),
+ ]
+ )
else:
- uow.dependencies.update([
- (parent_saves, after_save),
- (after_save, child_saves),
- (after_save, child_deletes),
-
- (child_saves, parent_deletes),
- (child_deletes, parent_deletes),
-
- (before_delete, child_saves),
- (before_delete, child_deletes),
- ])
-
- def per_state_dependencies(self, uow,
- save_parent,
- delete_parent,
- child_action,
- after_save, before_delete,
- isdelete, childisdelete):
+ uow.dependencies.update(
+ [
+ (parent_saves, after_save),
+ (after_save, child_saves),
+ (after_save, child_deletes),
+ (child_saves, parent_deletes),
+ (child_deletes, parent_deletes),
+ (before_delete, child_saves),
+ (before_delete, child_deletes),
+ ]
+ )
+
+ def per_state_dependencies(
+ self,
+ uow,
+ save_parent,
+ delete_parent,
+ child_action,
+ after_save,
+ before_delete,
+ isdelete,
+ childisdelete,
+ ):
if self.post_update:
child_post_updates = unitofwork.PostUpdateAll(
- uow,
- self.mapper.primary_base_mapper,
- False)
+ uow, self.mapper.primary_base_mapper, False
+ )
child_pre_updates = unitofwork.PostUpdateAll(
- uow,
- self.mapper.primary_base_mapper,
- True)
+ uow, self.mapper.primary_base_mapper, True
+ )
# TODO: this whole block is not covered
# by any tests
if not isdelete:
if childisdelete:
- uow.dependencies.update([
- (child_action, after_save),
- (after_save, child_post_updates),
- ])
+ uow.dependencies.update(
+ [
+ (child_action, after_save),
+ (after_save, child_post_updates),
+ ]
+ )
else:
- uow.dependencies.update([
- (save_parent, after_save),
- (child_action, after_save),
- (after_save, child_post_updates),
- ])
+ uow.dependencies.update(
+ [
+ (save_parent, after_save),
+ (child_action, after_save),
+ (after_save, child_post_updates),
+ ]
+ )
else:
if childisdelete:
- uow.dependencies.update([
- (before_delete, child_pre_updates),
- (child_pre_updates, delete_parent),
- ])
+ uow.dependencies.update(
+ [
+ (before_delete, child_pre_updates),
+ (child_pre_updates, delete_parent),
+ ]
+ )
else:
- uow.dependencies.update([
- (before_delete, child_pre_updates),
- (child_pre_updates, delete_parent),
- ])
+ uow.dependencies.update(
+ [
+ (before_delete, child_pre_updates),
+ (child_pre_updates, delete_parent),
+ ]
+ )
elif not isdelete:
- uow.dependencies.update([
- (save_parent, after_save),
- (after_save, child_action),
- (save_parent, child_action)
- ])
+ uow.dependencies.update(
+ [
+ (save_parent, after_save),
+ (after_save, child_action),
+ (save_parent, child_action),
+ ]
+ )
else:
- uow.dependencies.update([
- (before_delete, child_action),
- (child_action, delete_parent)
- ])
+ uow.dependencies.update(
+ [(before_delete, child_action), (child_action, delete_parent)]
+ )
def presort_deletes(self, uowcommit, states):
# head object is being deleted, and we manage its list of
# child objects; the child objects have to have their
# foreign key to the parent set to NULL
- should_null_fks = not self.cascade.delete and \
- not self.passive_deletes == 'all'
+ should_null_fks = (
+ not self.cascade.delete and not self.passive_deletes == "all"
+ )
for state in states:
history = uowcommit.get_attribute_history(
- state,
- self.key,
- self._passive_delete_flag)
+ state, self.key, self._passive_delete_flag
+ )
if history:
for child in history.deleted:
if child is not None and self.hasparent(child) is False:
@@ -429,13 +439,16 @@ class OneToManyDP(DependencyProcessor):
for child in history.unchanged:
if child is not None:
uowcommit.register_object(
- child, operation="delete", prop=self.prop)
+ child, operation="delete", prop=self.prop
+ )
def presort_saves(self, uowcommit, states):
- children_added = uowcommit.memo(('children_added', self), set)
+ children_added = uowcommit.memo(("children_added", self), set)
- should_null_fks = not self.cascade.delete_orphan and \
- not self.passive_deletes == 'all'
+ should_null_fks = (
+ not self.cascade.delete_orphan
+ and not self.passive_deletes == "all"
+ )
for state in states:
pks_changed = self._pks_changed(uowcommit, state)
@@ -445,34 +458,39 @@ class OneToManyDP(DependencyProcessor):
else:
passive = attributes.PASSIVE_OFF
- history = uowcommit.get_attribute_history(
- state,
- self.key,
- passive)
+ history = uowcommit.get_attribute_history(state, self.key, passive)
if history:
for child in history.added:
if child is not None:
- uowcommit.register_object(child, cancel_delete=True,
- operation="add",
- prop=self.prop)
+ uowcommit.register_object(
+ child,
+ cancel_delete=True,
+ operation="add",
+ prop=self.prop,
+ )
children_added.update(history.added)
for child in history.deleted:
if not self.cascade.delete_orphan:
if should_null_fks:
- uowcommit.register_object(child, isdelete=False,
- operation='delete',
- prop=self.prop)
+ uowcommit.register_object(
+ child,
+ isdelete=False,
+ operation="delete",
+ prop=self.prop,
+ )
elif self.hasparent(child) is False:
uowcommit.register_object(
- child, isdelete=True,
- operation="delete", prop=self.prop)
+ child,
+ isdelete=True,
+ operation="delete",
+ prop=self.prop,
+ )
for c, m, st_, dct_ in self.mapper.cascade_iterator(
- 'delete', child):
- uowcommit.register_object(
- st_,
- isdelete=True)
+ "delete", child
+ ):
+ uowcommit.register_object(st_, isdelete=True)
if pks_changed:
if history:
@@ -483,7 +501,8 @@ class OneToManyDP(DependencyProcessor):
False,
self.passive_updates,
operation="pk change",
- prop=self.prop)
+ prop=self.prop,
+ )
def process_deletes(self, uowcommit, states):
# head object is being deleted, and we manage its list of
@@ -492,39 +511,37 @@ class OneToManyDP(DependencyProcessor):
# safely for any cascade but is unnecessary if delete cascade
# is on.
- if self.post_update or not self.passive_deletes == 'all':
- children_added = uowcommit.memo(('children_added', self), set)
+ if self.post_update or not self.passive_deletes == "all":
+ children_added = uowcommit.memo(("children_added", self), set)
for state in states:
history = uowcommit.get_attribute_history(
- state,
- self.key,
- self._passive_delete_flag)
+ state, self.key, self._passive_delete_flag
+ )
if history:
for child in history.deleted:
- if child is not None and \
- self.hasparent(child) is False:
+ if (
+ child is not None
+ and self.hasparent(child) is False
+ ):
self._synchronize(
- state,
- child,
- None, True,
- uowcommit, False)
+ state, child, None, True, uowcommit, False
+ )
if self.post_update and child:
self._post_update(child, uowcommit, [state])
if self.post_update or not self.cascade.delete:
- for child in set(history.unchanged).\
- difference(children_added):
+ for child in set(history.unchanged).difference(
+ children_added
+ ):
if child is not None:
self._synchronize(
- state,
- child,
- None, True,
- uowcommit, False)
+ state, child, None, True, uowcommit, False
+ )
if self.post_update and child:
- self._post_update(child,
- uowcommit,
- [state])
+ self._post_update(
+ child, uowcommit, [state]
+ )
# technically, we can even remove each child from the
# collection here too. but this would be a somewhat
@@ -532,54 +549,66 @@ class OneToManyDP(DependencyProcessor):
# if the old parent wasn't deleted but child was moved.
def process_saves(self, uowcommit, states):
- should_null_fks = not self.cascade.delete_orphan and \
- not self.passive_deletes == 'all'
+ should_null_fks = (
+ not self.cascade.delete_orphan
+ and not self.passive_deletes == "all"
+ )
for state in states:
history = uowcommit.get_attribute_history(
- state,
- self.key,
- attributes.PASSIVE_NO_INITIALIZE)
+ state, self.key, attributes.PASSIVE_NO_INITIALIZE
+ )
if history:
for child in history.added:
- self._synchronize(state, child, None,
- False, uowcommit, False)
+ self._synchronize(
+ state, child, None, False, uowcommit, False
+ )
if child is not None and self.post_update:
self._post_update(child, uowcommit, [state])
for child in history.deleted:
- if should_null_fks and not self.cascade.delete_orphan and \
- not self.hasparent(child):
- self._synchronize(state, child, None, True,
- uowcommit, False)
+ if (
+ should_null_fks
+ and not self.cascade.delete_orphan
+ and not self.hasparent(child)
+ ):
+ self._synchronize(
+ state, child, None, True, uowcommit, False
+ )
if self._pks_changed(uowcommit, state):
for child in history.unchanged:
- self._synchronize(state, child, None,
- False, uowcommit, True)
+ self._synchronize(
+ state, child, None, False, uowcommit, True
+ )
- def _synchronize(self, state, child,
- associationrow, clearkeys, uowcommit,
- pks_changed):
+ def _synchronize(
+ self, state, child, associationrow, clearkeys, uowcommit, pks_changed
+ ):
source = state
dest = child
self._verify_canload(child)
- if dest is None or \
- (not self.post_update and uowcommit.is_deleted(dest)):
+ if dest is None or (
+ not self.post_update and uowcommit.is_deleted(dest)
+ ):
return
if clearkeys:
sync.clear(dest, self.mapper, self.prop.synchronize_pairs)
else:
- sync.populate(source, self.parent, dest, self.mapper,
- self.prop.synchronize_pairs, uowcommit,
- self.passive_updates and pks_changed)
+ sync.populate(
+ source,
+ self.parent,
+ dest,
+ self.mapper,
+ self.prop.synchronize_pairs,
+ uowcommit,
+ self.passive_updates and pks_changed,
+ )
def _pks_changed(self, uowcommit, state):
return sync.source_modified(
- uowcommit,
- state,
- self.parent,
- self.prop.synchronize_pairs)
+ uowcommit, state, self.parent, self.prop.synchronize_pairs
+ )
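
The child_post_updates/child_pre_updates branches above apply only when the relationship is configured with post_update=True, which trades the single-statement flush for an extra UPDATE issued after INSERTs (or before DELETEs). A minimal declarative sketch with hypothetical names:

from sqlalchemy import Column, ForeignKey, Integer
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship

Base = declarative_base()

class Parent(Base):
    __tablename__ = "parent"
    id = Column(Integer, primary_key=True)
    # post_update=True routes child.parent_id through the PostUpdateAll
    # steps above: rows INSERT first, then the FK is set by an UPDATE,
    # avoiding chicken-and-egg ordering between mutually dependent rows
    children = relationship("Child", post_update=True)

class Child(Base):
    __tablename__ = "child"
    id = Column(Integer, primary_key=True)
    parent_id = Column(Integer, ForeignKey("parent.id"))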
class ManyToOneDP(DependencyProcessor):
@@ -587,105 +616,110 @@ class ManyToOneDP(DependencyProcessor):
DependencyProcessor.__init__(self, prop)
self.mapper._dependency_processors.append(DetectKeySwitch(prop))
- def per_property_dependencies(self, uow,
- parent_saves,
- child_saves,
- parent_deletes,
- child_deletes,
- after_save,
- before_delete):
+ def per_property_dependencies(
+ self,
+ uow,
+ parent_saves,
+ child_saves,
+ parent_deletes,
+ child_deletes,
+ after_save,
+ before_delete,
+ ):
if self.post_update:
parent_post_updates = unitofwork.PostUpdateAll(
- uow,
- self.parent.primary_base_mapper,
- False)
+ uow, self.parent.primary_base_mapper, False
+ )
parent_pre_updates = unitofwork.PostUpdateAll(
- uow,
- self.parent.primary_base_mapper,
- True)
-
- uow.dependencies.update([
- (child_saves, after_save),
- (parent_saves, after_save),
- (after_save, parent_post_updates),
-
- (after_save, parent_pre_updates),
- (before_delete, parent_pre_updates),
-
- (parent_pre_updates, child_deletes),
- (parent_pre_updates, parent_deletes),
- ])
+ uow, self.parent.primary_base_mapper, True
+ )
+
+ uow.dependencies.update(
+ [
+ (child_saves, after_save),
+ (parent_saves, after_save),
+ (after_save, parent_post_updates),
+ (after_save, parent_pre_updates),
+ (before_delete, parent_pre_updates),
+ (parent_pre_updates, child_deletes),
+ (parent_pre_updates, parent_deletes),
+ ]
+ )
else:
- uow.dependencies.update([
- (child_saves, after_save),
- (after_save, parent_saves),
- (parent_saves, child_deletes),
- (parent_deletes, child_deletes)
- ])
-
- def per_state_dependencies(self, uow,
- save_parent,
- delete_parent,
- child_action,
- after_save, before_delete,
- isdelete, childisdelete):
+ uow.dependencies.update(
+ [
+ (child_saves, after_save),
+ (after_save, parent_saves),
+ (parent_saves, child_deletes),
+ (parent_deletes, child_deletes),
+ ]
+ )
+
+ def per_state_dependencies(
+ self,
+ uow,
+ save_parent,
+ delete_parent,
+ child_action,
+ after_save,
+ before_delete,
+ isdelete,
+ childisdelete,
+ ):
if self.post_update:
if not isdelete:
parent_post_updates = unitofwork.PostUpdateAll(
- uow,
- self.parent.primary_base_mapper,
- False)
+ uow, self.parent.primary_base_mapper, False
+ )
if childisdelete:
- uow.dependencies.update([
- (after_save, parent_post_updates),
- (parent_post_updates, child_action)
- ])
+ uow.dependencies.update(
+ [
+ (after_save, parent_post_updates),
+ (parent_post_updates, child_action),
+ ]
+ )
else:
- uow.dependencies.update([
- (save_parent, after_save),
- (child_action, after_save),
-
- (after_save, parent_post_updates)
- ])
+ uow.dependencies.update(
+ [
+ (save_parent, after_save),
+ (child_action, after_save),
+ (after_save, parent_post_updates),
+ ]
+ )
else:
parent_pre_updates = unitofwork.PostUpdateAll(
- uow,
- self.parent.primary_base_mapper,
- True)
+ uow, self.parent.primary_base_mapper, True
+ )
- uow.dependencies.update([
- (before_delete, parent_pre_updates),
- (parent_pre_updates, delete_parent),
- (parent_pre_updates, child_action)
- ])
+ uow.dependencies.update(
+ [
+ (before_delete, parent_pre_updates),
+ (parent_pre_updates, delete_parent),
+ (parent_pre_updates, child_action),
+ ]
+ )
elif not isdelete:
if not childisdelete:
- uow.dependencies.update([
- (child_action, after_save),
- (after_save, save_parent),
- ])
+ uow.dependencies.update(
+ [(child_action, after_save), (after_save, save_parent)]
+ )
else:
- uow.dependencies.update([
- (after_save, save_parent),
- ])
+ uow.dependencies.update([(after_save, save_parent)])
else:
if childisdelete:
- uow.dependencies.update([
- (delete_parent, child_action)
- ])
+ uow.dependencies.update([(delete_parent, child_action)])
def presort_deletes(self, uowcommit, states):
if self.cascade.delete or self.cascade.delete_orphan:
for state in states:
history = uowcommit.get_attribute_history(
- state,
- self.key,
- self._passive_delete_flag)
+ state, self.key, self._passive_delete_flag
+ )
if history:
if self.cascade.delete_orphan:
todelete = history.sum()
@@ -695,36 +729,42 @@ class ManyToOneDP(DependencyProcessor):
if child is None:
continue
uowcommit.register_object(
- child, isdelete=True,
- operation="delete", prop=self.prop)
- t = self.mapper.cascade_iterator('delete', child)
+ child,
+ isdelete=True,
+ operation="delete",
+ prop=self.prop,
+ )
+ t = self.mapper.cascade_iterator("delete", child)
for c, m, st_, dct_ in t:
- uowcommit.register_object(
- st_, isdelete=True)
+ uowcommit.register_object(st_, isdelete=True)
def presort_saves(self, uowcommit, states):
for state in states:
uowcommit.register_object(state, operation="add", prop=self.prop)
if self.cascade.delete_orphan:
history = uowcommit.get_attribute_history(
- state,
- self.key,
- self._passive_delete_flag)
+ state, self.key, self._passive_delete_flag
+ )
if history:
for child in history.deleted:
if self.hasparent(child) is False:
uowcommit.register_object(
- child, isdelete=True,
- operation="delete", prop=self.prop)
+ child,
+ isdelete=True,
+ operation="delete",
+ prop=self.prop,
+ )
- t = self.mapper.cascade_iterator('delete', child)
+ t = self.mapper.cascade_iterator("delete", child)
for c, m, st_, dct_ in t:
uowcommit.register_object(st_, isdelete=True)
def process_deletes(self, uowcommit, states):
- if self.post_update and \
- not self.cascade.delete_orphan and \
- not self.passive_deletes == 'all':
+ if (
+ self.post_update
+ and not self.cascade.delete_orphan
+ and not self.passive_deletes == "all"
+ ):
# post_update means we have to update our
# row to not reference the child object
@@ -733,55 +773,70 @@ class ManyToOneDP(DependencyProcessor):
self._synchronize(state, None, None, True, uowcommit)
if state and self.post_update:
history = uowcommit.get_attribute_history(
- state,
- self.key,
- self._passive_delete_flag)
+ state, self.key, self._passive_delete_flag
+ )
if history:
self._post_update(
- state, uowcommit, history.sum(),
- is_m2o_delete=True)
+ state, uowcommit, history.sum(), is_m2o_delete=True
+ )
def process_saves(self, uowcommit, states):
for state in states:
history = uowcommit.get_attribute_history(
- state,
- self.key,
- attributes.PASSIVE_NO_INITIALIZE)
+ state, self.key, attributes.PASSIVE_NO_INITIALIZE
+ )
if history:
if history.added:
for child in history.added:
- self._synchronize(state, child, None, False,
- uowcommit, "add")
+ self._synchronize(
+ state, child, None, False, uowcommit, "add"
+ )
elif history.deleted:
self._synchronize(
- state, None, None, True, uowcommit, "delete")
+ state, None, None, True, uowcommit, "delete"
+ )
if self.post_update:
self._post_update(state, uowcommit, history.sum())
- def _synchronize(self, state, child, associationrow,
- clearkeys, uowcommit, operation=None):
- if state is None or \
- (not self.post_update and uowcommit.is_deleted(state)):
+ def _synchronize(
+ self,
+ state,
+ child,
+ associationrow,
+ clearkeys,
+ uowcommit,
+ operation=None,
+ ):
+ if state is None or (
+ not self.post_update and uowcommit.is_deleted(state)
+ ):
return
- if operation is not None and \
- child is not None and \
- not uowcommit.session._contains_state(child):
+ if (
+ operation is not None
+ and child is not None
+ and not uowcommit.session._contains_state(child)
+ ):
util.warn(
"Object of type %s not in session, %s "
- "operation along '%s' won't proceed" %
- (mapperutil.state_class_str(child), operation, self.prop))
+ "operation along '%s' won't proceed"
+ % (mapperutil.state_class_str(child), operation, self.prop)
+ )
return
if clearkeys or child is None:
sync.clear(state, self.parent, self.prop.synchronize_pairs)
else:
self._verify_canload(child)
- sync.populate(child, self.mapper, state,
- self.parent,
- self.prop.synchronize_pairs,
- uowcommit,
- False)
+ sync.populate(
+ child,
+ self.mapper,
+ state,
+ self.parent,
+ self.prop.synchronize_pairs,
+ uowcommit,
+ False,
+ )
class DetectKeySwitch(DependencyProcessor):
@@ -801,20 +856,18 @@ class DetectKeySwitch(DependencyProcessor):
if self.passive_updates:
return
else:
- if False in (prop.passive_updates for
- prop in self.prop._reverse_property):
+ if False in (
+ prop.passive_updates
+ for prop in self.prop._reverse_property
+ ):
return
uow.register_preprocessor(self, False)
def per_property_flush_actions(self, uow):
- parent_saves = unitofwork.SaveUpdateAll(
- uow,
- self.parent.base_mapper)
+ parent_saves = unitofwork.SaveUpdateAll(uow, self.parent.base_mapper)
after_save = unitofwork.ProcessAll(uow, self, False, False)
- uow.dependencies.update([
- (parent_saves, after_save)
- ])
+ uow.dependencies.update([(parent_saves, after_save)])
def per_state_flush_actions(self, uow, states, isdelete):
pass
@@ -848,8 +901,7 @@ class DetectKeySwitch(DependencyProcessor):
def _key_switchers(self, uow, states):
switched, notswitched = uow.memo(
- ('pk_switchers', self),
- lambda: (set(), set())
+ ("pk_switchers", self), lambda: (set(), set())
)
allstates = switched.union(notswitched)
@@ -871,74 +923,86 @@ class DetectKeySwitch(DependencyProcessor):
continue
dict_ = state.dict
related = state.get_impl(self.key).get(
- state, dict_, passive=self._passive_update_flag)
- if related is not attributes.PASSIVE_NO_RESULT and \
- related is not None:
+ state, dict_, passive=self._passive_update_flag
+ )
+ if (
+ related is not attributes.PASSIVE_NO_RESULT
+ and related is not None
+ ):
related_state = attributes.instance_state(dict_[self.key])
if related_state in switchers:
- uowcommit.register_object(state,
- False,
- self.passive_updates)
+ uowcommit.register_object(
+ state, False, self.passive_updates
+ )
sync.populate(
related_state,
- self.mapper, state,
- self.parent, self.prop.synchronize_pairs,
- uowcommit, self.passive_updates)
+ self.mapper,
+ state,
+ self.parent,
+ self.prop.synchronize_pairs,
+ uowcommit,
+ self.passive_updates,
+ )
def _pks_changed(self, uowcommit, state):
return bool(state.key) and sync.source_modified(
- uowcommit, state, self.mapper, self.prop.synchronize_pairs)
+ uowcommit, state, self.mapper, self.prop.synchronize_pairs
+ )
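
DetectKeySwitch is the machinery behind passive_updates=False: when a referenced primary key changes in Python, the unit of work chases down dependent foreign keys itself rather than relying on ON UPDATE CASCADE in the database. A sketch with hypothetical names:

from sqlalchemy import Column, ForeignKey, Integer, String
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship

Base = declarative_base()

class User(Base):
    __tablename__ = "user"
    username = Column(String(50), primary_key=True)

class Address(Base):
    __tablename__ = "address"
    id = Column(Integer, primary_key=True)
    username = Column(String(50), ForeignKey("user.username"))
    # passive_updates=False registers DetectKeySwitch, so a change to
    # User.username is propagated to address.username during flush
    user = relationship(User, passive_updates=False)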
class ManyToManyDP(DependencyProcessor):
+ def per_property_dependencies(
+ self,
+ uow,
+ parent_saves,
+ child_saves,
+ parent_deletes,
+ child_deletes,
+ after_save,
+ before_delete,
+ ):
+
+ uow.dependencies.update(
+ [
+ (parent_saves, after_save),
+ (child_saves, after_save),
+ (after_save, child_deletes),
+ # a rowswitch on the parent from deleted to saved
+ # can make this one occur, as the "save" may remove
+ # an element from the
+ # "deleted" list before we have a chance to
+ # process its child rows
+ (before_delete, parent_saves),
+ (before_delete, parent_deletes),
+ (before_delete, child_deletes),
+ (before_delete, child_saves),
+ ]
+ )
- def per_property_dependencies(self, uow, parent_saves,
- child_saves,
- parent_deletes,
- child_deletes,
- after_save,
- before_delete
- ):
-
- uow.dependencies.update([
- (parent_saves, after_save),
- (child_saves, after_save),
- (after_save, child_deletes),
-
- # a rowswitch on the parent from deleted to saved
- # can make this one occur, as the "save" may remove
- # an element from the
- # "deleted" list before we have a chance to
- # process its child rows
- (before_delete, parent_saves),
-
- (before_delete, parent_deletes),
- (before_delete, child_deletes),
- (before_delete, child_saves),
- ])
-
- def per_state_dependencies(self, uow,
- save_parent,
- delete_parent,
- child_action,
- after_save, before_delete,
- isdelete, childisdelete):
+ def per_state_dependencies(
+ self,
+ uow,
+ save_parent,
+ delete_parent,
+ child_action,
+ after_save,
+ before_delete,
+ isdelete,
+ childisdelete,
+ ):
if not isdelete:
if childisdelete:
- uow.dependencies.update([
- (save_parent, after_save),
- (after_save, child_action),
- ])
+ uow.dependencies.update(
+ [(save_parent, after_save), (after_save, child_action)]
+ )
else:
- uow.dependencies.update([
- (save_parent, after_save),
- (child_action, after_save),
- ])
+ uow.dependencies.update(
+ [(save_parent, after_save), (child_action, after_save)]
+ )
else:
- uow.dependencies.update([
- (before_delete, child_action),
- (before_delete, delete_parent)
- ])
+ uow.dependencies.update(
+ [(before_delete, child_action), (before_delete, delete_parent)]
+ )
def presort_deletes(self, uowcommit, states):
# TODO: no tests fail if this whole
@@ -949,9 +1013,8 @@ class ManyToManyDP(DependencyProcessor):
# returns True
for state in states:
uowcommit.get_attribute_history(
- state,
- self.key,
- self._passive_delete_flag)
+ state, self.key, self._passive_delete_flag
+ )
def presort_saves(self, uowcommit, states):
if not self.passive_updates:
@@ -961,9 +1024,8 @@ class ManyToManyDP(DependencyProcessor):
for state in states:
if self._pks_changed(uowcommit, state):
history = uowcommit.get_attribute_history(
- state,
- self.key,
- attributes.PASSIVE_OFF)
+ state, self.key, attributes.PASSIVE_OFF
+ )
if not self.cascade.delete_orphan:
return
@@ -972,20 +1034,21 @@ class ManyToManyDP(DependencyProcessor):
# if delete_orphan check is turned on.
for state in states:
history = uowcommit.get_attribute_history(
- state,
- self.key,
- attributes.PASSIVE_NO_INITIALIZE)
+ state, self.key, attributes.PASSIVE_NO_INITIALIZE
+ )
if history:
for child in history.deleted:
if self.hasparent(child) is False:
uowcommit.register_object(
- child, isdelete=True,
- operation="delete", prop=self.prop)
+ child,
+ isdelete=True,
+ operation="delete",
+ prop=self.prop,
+ )
for c, m, st_, dct_ in self.mapper.cascade_iterator(
- 'delete',
- child):
- uowcommit.register_object(
- st_, isdelete=True)
+ "delete", child
+ ):
+ uowcommit.register_object(st_, isdelete=True)
def process_deletes(self, uowcommit, states):
secondary_delete = []
@@ -998,21 +1061,23 @@ class ManyToManyDP(DependencyProcessor):
# this history should be cached already, as
# we loaded it in preprocess_deletes
history = uowcommit.get_attribute_history(
- state,
- self.key,
- self._passive_delete_flag)
+ state, self.key, self._passive_delete_flag
+ )
if history:
for child in history.non_added():
- if child is None or \
- (processed is not None and
- (state, child) in processed):
+ if child is None or (
+ processed is not None and (state, child) in processed
+ ):
continue
associationrow = {}
if not self._synchronize(
- state,
- child,
- associationrow,
- False, uowcommit, "delete"):
+ state,
+ child,
+ associationrow,
+ False,
+ uowcommit,
+ "delete",
+ ):
continue
secondary_delete.append(associationrow)
@@ -1021,8 +1086,9 @@ class ManyToManyDP(DependencyProcessor):
if processed is not None:
processed.update(tmp)
- self._run_crud(uowcommit, secondary_insert,
- secondary_update, secondary_delete)
+ self._run_crud(
+ uowcommit, secondary_insert, secondary_update, secondary_delete
+ )
def process_saves(self, uowcommit, states):
secondary_delete = []
@@ -1033,110 +1099,133 @@ class ManyToManyDP(DependencyProcessor):
tmp = set()
for state in states:
- need_cascade_pks = not self.passive_updates and \
- self._pks_changed(uowcommit, state)
+ need_cascade_pks = not self.passive_updates and self._pks_changed(
+ uowcommit, state
+ )
if need_cascade_pks:
passive = attributes.PASSIVE_OFF
else:
passive = attributes.PASSIVE_NO_INITIALIZE
- history = uowcommit.get_attribute_history(state, self.key,
- passive)
+ history = uowcommit.get_attribute_history(state, self.key, passive)
if history:
for child in history.added:
- if (processed is not None and
- (state, child) in processed):
+ if processed is not None and (state, child) in processed:
continue
associationrow = {}
- if not self._synchronize(state,
- child,
- associationrow,
- False, uowcommit, "add"):
+ if not self._synchronize(
+ state, child, associationrow, False, uowcommit, "add"
+ ):
continue
secondary_insert.append(associationrow)
for child in history.deleted:
- if (processed is not None and
- (state, child) in processed):
+ if processed is not None and (state, child) in processed:
continue
associationrow = {}
- if not self._synchronize(state,
- child,
- associationrow,
- False, uowcommit, "delete"):
+ if not self._synchronize(
+ state,
+ child,
+ associationrow,
+ False,
+ uowcommit,
+ "delete",
+ ):
continue
secondary_delete.append(associationrow)
- tmp.update((c, state)
- for c in history.added + history.deleted)
+ tmp.update((c, state) for c in history.added + history.deleted)
if need_cascade_pks:
for child in history.unchanged:
associationrow = {}
- sync.update(state,
- self.parent,
- associationrow,
- "old_",
- self.prop.synchronize_pairs)
- sync.update(child,
- self.mapper,
- associationrow,
- "old_",
- self.prop.secondary_synchronize_pairs)
+ sync.update(
+ state,
+ self.parent,
+ associationrow,
+ "old_",
+ self.prop.synchronize_pairs,
+ )
+ sync.update(
+ child,
+ self.mapper,
+ associationrow,
+ "old_",
+ self.prop.secondary_synchronize_pairs,
+ )
secondary_update.append(associationrow)
if processed is not None:
processed.update(tmp)
- self._run_crud(uowcommit, secondary_insert,
- secondary_update, secondary_delete)
+ self._run_crud(
+ uowcommit, secondary_insert, secondary_update, secondary_delete
+ )
- def _run_crud(self, uowcommit, secondary_insert,
- secondary_update, secondary_delete):
+ def _run_crud(
+ self, uowcommit, secondary_insert, secondary_update, secondary_delete
+ ):
connection = uowcommit.transaction.connection(self.mapper)
if secondary_delete:
associationrow = secondary_delete[0]
- statement = self.secondary.delete(sql.and_(*[
- c == sql.bindparam(c.key, type_=c.type)
- for c in self.secondary.c
- if c.key in associationrow
- ]))
+ statement = self.secondary.delete(
+ sql.and_(
+ *[
+ c == sql.bindparam(c.key, type_=c.type)
+ for c in self.secondary.c
+ if c.key in associationrow
+ ]
+ )
+ )
result = connection.execute(statement, secondary_delete)
- if result.supports_sane_multi_rowcount() and \
- result.rowcount != len(secondary_delete):
+ if result.supports_sane_multi_rowcount() and result.rowcount != len(
+ secondary_delete
+ ):
raise exc.StaleDataError(
"DELETE statement on table '%s' expected to delete "
- "%d row(s); Only %d were matched." %
- (self.secondary.description, len(secondary_delete),
- result.rowcount)
+ "%d row(s); Only %d were matched."
+ % (
+ self.secondary.description,
+ len(secondary_delete),
+ result.rowcount,
+ )
)
if secondary_update:
associationrow = secondary_update[0]
- statement = self.secondary.update(sql.and_(*[
- c == sql.bindparam("old_" + c.key, type_=c.type)
- for c in self.secondary.c
- if c.key in associationrow
- ]))
+ statement = self.secondary.update(
+ sql.and_(
+ *[
+ c == sql.bindparam("old_" + c.key, type_=c.type)
+ for c in self.secondary.c
+ if c.key in associationrow
+ ]
+ )
+ )
result = connection.execute(statement, secondary_update)
- if result.supports_sane_multi_rowcount() and \
- result.rowcount != len(secondary_update):
+ if result.supports_sane_multi_rowcount() and result.rowcount != len(
+ secondary_update
+ ):
raise exc.StaleDataError(
"UPDATE statement on table '%s' expected to update "
- "%d row(s); Only %d were matched." %
- (self.secondary.description, len(secondary_update),
- result.rowcount)
+ "%d row(s); Only %d were matched."
+ % (
+ self.secondary.description,
+ len(secondary_update),
+ result.rowcount,
+ )
)
if secondary_insert:
statement = self.secondary.insert()
connection.execute(statement, secondary_insert)
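
_run_crud above speaks plain Core INSERT/UPDATE/DELETE against the association table; the associationrow dicts it executes are built by _synchronize from a relationship configured with secondary=. A minimal sketch (table and class names hypothetical):

from sqlalchemy import Column, ForeignKey, Integer, Table
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship

Base = declarative_base()

association = Table(
    "association",
    Base.metadata,
    Column("left_id", Integer, ForeignKey("left.id"), primary_key=True),
    Column("right_id", Integer, ForeignKey("right.id"), primary_key=True),
)

class Left(Base):
    __tablename__ = "left"
    id = Column(Integer, primary_key=True)
    # changes to this collection become rows in 'association',
    # emitted by _run_crud during flush
    rights = relationship("Right", secondary=association)

class Right(Base):
    __tablename__ = "right"
    id = Column(Integer, primary_key=True)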
- def _synchronize(self, state, child, associationrow,
- clearkeys, uowcommit, operation):
+ def _synchronize(
+ self, state, child, associationrow, clearkeys, uowcommit, operation
+ ):
# this checks for None if uselist=True
self._verify_canload(child)
@@ -1150,23 +1239,28 @@ class ManyToManyDP(DependencyProcessor):
if not child.deleted:
util.warn(
"Object of type %s not in session, %s "
- "operation along '%s' won't proceed" %
- (mapperutil.state_class_str(child), operation, self.prop))
+ "operation along '%s' won't proceed"
+ % (mapperutil.state_class_str(child), operation, self.prop)
+ )
return False
- sync.populate_dict(state, self.parent, associationrow,
- self.prop.synchronize_pairs)
- sync.populate_dict(child, self.mapper, associationrow,
- self.prop.secondary_synchronize_pairs)
+ sync.populate_dict(
+ state, self.parent, associationrow, self.prop.synchronize_pairs
+ )
+ sync.populate_dict(
+ child,
+ self.mapper,
+ associationrow,
+ self.prop.secondary_synchronize_pairs,
+ )
return True
def _pks_changed(self, uowcommit, state):
return sync.source_modified(
- uowcommit,
- state,
- self.parent,
- self.prop.synchronize_pairs)
+ uowcommit, state, self.parent, self.prop.synchronize_pairs
+ )
+
_direction_to_processor = {
ONETOMANY: OneToManyDP,
diff --git a/lib/sqlalchemy/orm/deprecated_interfaces.py b/lib/sqlalchemy/orm/deprecated_interfaces.py
index 426288e03..6b51404d0 100644
--- a/lib/sqlalchemy/orm/deprecated_interfaces.py
+++ b/lib/sqlalchemy/orm/deprecated_interfaces.py
@@ -58,23 +58,25 @@ class MapperExtension(object):
@classmethod
def _adapt_instrument_class(cls, self, listener):
- cls._adapt_listener_methods(self, listener, ('instrument_class',))
+ cls._adapt_listener_methods(self, listener, ("instrument_class",))
@classmethod
def _adapt_listener(cls, self, listener):
cls._adapt_listener_methods(
- self, listener,
+ self,
+ listener,
(
- 'init_instance',
- 'init_failed',
- 'reconstruct_instance',
- 'before_insert',
- 'after_insert',
- 'before_update',
- 'after_update',
- 'before_delete',
- 'after_delete'
- ))
+ "init_instance",
+ "init_failed",
+ "reconstruct_instance",
+ "before_insert",
+ "after_insert",
+ "before_update",
+ "after_update",
+ "before_delete",
+ "after_delete",
+ ),
+ )
@classmethod
def _adapt_listener_methods(cls, self, listener, methods):
@@ -84,36 +86,75 @@ class MapperExtension(object):
ls_meth = getattr(listener, meth)
if not util.methods_equivalent(me_meth, ls_meth):
- if meth == 'reconstruct_instance':
+ if meth == "reconstruct_instance":
+
def go(ls_meth):
def reconstruct(instance, ctx):
ls_meth(self, instance)
+
return reconstruct
- event.listen(self.class_manager, 'load',
- go(ls_meth), raw=False, propagate=True)
- elif meth == 'init_instance':
+
+ event.listen(
+ self.class_manager,
+ "load",
+ go(ls_meth),
+ raw=False,
+ propagate=True,
+ )
+ elif meth == "init_instance":
+
def go(ls_meth):
def init_instance(instance, args, kwargs):
- ls_meth(self, self.class_,
- self.class_manager.original_init,
- instance, args, kwargs)
+ ls_meth(
+ self,
+ self.class_,
+ self.class_manager.original_init,
+ instance,
+ args,
+ kwargs,
+ )
+
return init_instance
- event.listen(self.class_manager, 'init',
- go(ls_meth), raw=False, propagate=True)
- elif meth == 'init_failed':
+
+ event.listen(
+ self.class_manager,
+ "init",
+ go(ls_meth),
+ raw=False,
+ propagate=True,
+ )
+ elif meth == "init_failed":
+
def go(ls_meth):
def init_failed(instance, args, kwargs):
util.warn_exception(
- ls_meth, self, self.class_,
+ ls_meth,
+ self,
+ self.class_,
self.class_manager.original_init,
- instance, args, kwargs)
+ instance,
+ args,
+ kwargs,
+ )
return init_failed
- event.listen(self.class_manager, 'init_failure',
- go(ls_meth), raw=False, propagate=True)
+
+ event.listen(
+ self.class_manager,
+ "init_failure",
+ go(ls_meth),
+ raw=False,
+ propagate=True,
+ )
else:
- event.listen(self, "%s" % meth, ls_meth,
- raw=False, retval=True, propagate=True)
+ event.listen(
+ self,
+ "%s" % meth,
+ ls_meth,
+ raw=False,
+ retval=True,
+ propagate=True,
+ )
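
The fallback branch above maps each remaining MapperExtension method onto the mapper-level event of the same name. New code can skip the adapter and listen directly; a sketch (the mapped class name is hypothetical):

from sqlalchemy import event

def before_insert(mapper, connection, target):
    # the same hook MapperExtension.before_insert is adapted to above
    pass

# event.listen(MyMappedClass, "before_insert", before_insert,
#              propagate=True)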
def instrument_class(self, mapper, class_):
"""Receive a class when the mapper is first constructed, and has
@@ -302,16 +343,16 @@ class SessionExtension(object):
@classmethod
def _adapt_listener(cls, self, listener):
for meth in [
- 'before_commit',
- 'after_commit',
- 'after_rollback',
- 'before_flush',
- 'after_flush',
- 'after_flush_postexec',
- 'after_begin',
- 'after_attach',
- 'after_bulk_update',
- 'after_bulk_delete',
+ "before_commit",
+ "after_commit",
+ "after_rollback",
+ "before_flush",
+ "after_flush",
+ "after_flush_postexec",
+ "after_begin",
+ "after_attach",
+ "after_bulk_update",
+ "after_bulk_delete",
]:
me_meth = getattr(SessionExtension, meth)
ls_meth = getattr(listener, meth)
@@ -450,15 +491,30 @@ class AttributeExtension(object):
@classmethod
def _adapt_listener(cls, self, listener):
- event.listen(self, 'append', listener.append,
- active_history=listener.active_history,
- raw=True, retval=True)
- event.listen(self, 'remove', listener.remove,
- active_history=listener.active_history,
- raw=True, retval=True)
- event.listen(self, 'set', listener.set,
- active_history=listener.active_history,
- raw=True, retval=True)
+ event.listen(
+ self,
+ "append",
+ listener.append,
+ active_history=listener.active_history,
+ raw=True,
+ retval=True,
+ )
+ event.listen(
+ self,
+ "remove",
+ listener.remove,
+ active_history=listener.active_history,
+ raw=True,
+ retval=True,
+ )
+ event.listen(
+ self,
+ "set",
+ listener.set,
+ active_history=listener.active_history,
+ raw=True,
+ retval=True,
+ )
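
The adaptation above is the whole of AttributeExtension in modern terms: the append/remove/set hooks become attribute events, with retval=True meaning the listener's return value replaces the incoming value. A sketch (class and attribute names hypothetical):

from sqlalchemy import event

def coerce_on_set(target, value, oldvalue, initiator):
    # with retval=True, this return value is what actually gets set
    return int(value)

# event.listen(MyClass.some_attr, "set", coerce_on_set,
#              retval=True, active_history=True)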
def append(self, state, value, initiator):
"""Receive a collection append event.
diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py
index fefd2d2a1..37517e84c 100644
--- a/lib/sqlalchemy/orm/descriptor_props.py
+++ b/lib/sqlalchemy/orm/descriptor_props.py
@@ -37,9 +37,11 @@ class DescriptorProperty(MapperProperty):
def __init__(self, key):
self.key = key
- if hasattr(prop, 'get_history'):
- def get_history(self, state, dict_,
- passive=attributes.PASSIVE_OFF):
+ if hasattr(prop, "get_history"):
+
+ def get_history(
+ self, state, dict_, passive=attributes.PASSIVE_OFF
+ ):
return prop.get_history(state, dict_, passive)
if self.descriptor is None:
@@ -48,6 +50,7 @@ class DescriptorProperty(MapperProperty):
self.descriptor = desc
if self.descriptor is None:
+
def fset(obj, value):
setattr(obj, self.name, value)
@@ -57,21 +60,16 @@ class DescriptorProperty(MapperProperty):
def fget(obj):
return getattr(obj, self.name)
- self.descriptor = property(
- fget=fget,
- fset=fset,
- fdel=fdel,
- )
+ self.descriptor = property(fget=fget, fset=fset, fdel=fdel)
- proxy_attr = attributes.create_proxied_attribute(
- self.descriptor)(
- self.parent.class_,
- self.key,
- self.descriptor,
- lambda: self._comparator_factory(mapper),
- doc=self.doc,
- original_property=self
- )
+ proxy_attr = attributes.create_proxied_attribute(self.descriptor)(
+ self.parent.class_,
+ self.key,
+ self.descriptor,
+ lambda: self._comparator_factory(mapper),
+ doc=self.doc,
+ original_property=self,
+ )
proxy_attr.impl = _ProxyImpl(self.key)
mapper.class_manager.instrument_attribute(self.key, proxy_attr)
@@ -149,13 +147,14 @@ class CompositeProperty(DescriptorProperty):
self.attrs = attrs
self.composite_class = class_
- self.active_history = kwargs.get('active_history', False)
- self.deferred = kwargs.get('deferred', False)
- self.group = kwargs.get('group', None)
- self.comparator_factory = kwargs.pop('comparator_factory',
- self.__class__.Comparator)
- if 'info' in kwargs:
- self.info = kwargs.pop('info')
+ self.active_history = kwargs.get("active_history", False)
+ self.deferred = kwargs.get("deferred", False)
+ self.group = kwargs.get("group", None)
+ self.comparator_factory = kwargs.pop(
+ "comparator_factory", self.__class__.Comparator
+ )
+ if "info" in kwargs:
+ self.info = kwargs.pop("info")
util.set_creation_order(self)
self._create_descriptor()
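
For orientation, the protocol CompositeProperty drives: the composite class takes its column values positionally and hands them back via __composite_values__(), which the fget/fset and get_history logic below round-trips through. A standard sketch:

from sqlalchemy.orm import composite

class Point(object):
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __composite_values__(self):
        # the tuple the property splits back into individual columns
        return self.x, self.y

    def __eq__(self, other):
        return (
            isinstance(other, Point)
            and other.x == self.x
            and other.y == self.y
        )

# start = composite(Point, Column("x1", Integer), Column("y1", Integer))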
@@ -186,8 +185,7 @@ class CompositeProperty(DescriptorProperty):
# attributes, retrieve their values. This
# ensures they all load.
values = [
- getattr(instance, key)
- for key in self._attribute_keys
+ getattr(instance, key) for key in self._attribute_keys
]
# current expected behavior here is that the composite is
@@ -196,8 +194,7 @@ class CompositeProperty(DescriptorProperty):
# if the composite were created unconditionally,
# but that would be a behavioral change.
if self.key not in dict_ and (
- state.key is not None or
- not _none_set.issuperset(values)
+ state.key is not None or not _none_set.issuperset(values)
):
dict_[self.key] = self.composite_class(*values)
state.manager.dispatch.refresh(state, None, [self.key])
@@ -217,8 +214,8 @@ class CompositeProperty(DescriptorProperty):
setattr(instance, key, None)
else:
for key, value in zip(
- self._attribute_keys,
- value.__composite_values__()):
+ self._attribute_keys, value.__composite_values__()
+ ):
setattr(instance, key, value)
def fdel(instance):
@@ -234,18 +231,14 @@ class CompositeProperty(DescriptorProperty):
@util.memoized_property
def _comparable_elements(self):
- return [
- getattr(self.parent.class_, prop.key)
- for prop in self.props
- ]
+ return [getattr(self.parent.class_, prop.key) for prop in self.props]
@util.memoized_property
def props(self):
props = []
for attr in self.attrs:
if isinstance(attr, str):
- prop = self.parent.get_property(
- attr, _configure_mappers=False)
+ prop = self.parent.get_property(attr, _configure_mappers=False)
elif isinstance(attr, schema.Column):
prop = self.parent._columntoproperty[attr]
elif isinstance(attr, attributes.InstrumentedAttribute):
@@ -254,7 +247,8 @@ class CompositeProperty(DescriptorProperty):
raise sa_exc.ArgumentError(
"Composite expects Column objects or mapped "
"attributes/attribute names as arguments, got: %r"
- % (attr,))
+ % (attr,)
+ )
props.append(prop)
return props
@@ -271,9 +265,7 @@ class CompositeProperty(DescriptorProperty):
prop.active_history = self.active_history
if self.deferred:
prop.deferred = self.deferred
- prop.strategy_key = (
- ("deferred", True),
- ("instrument", True))
+ prop.strategy_key = (("deferred", True), ("instrument", True))
prop.group = self.group
def _setup_event_handlers(self):
@@ -299,8 +291,7 @@ class CompositeProperty(DescriptorProperty):
return
dict_[self.key] = self.composite_class(
- *[state.dict[key] for key in
- self._attribute_keys]
+ *[state.dict[key] for key in self._attribute_keys]
)
def expire_handler(state, keys):
@@ -317,24 +308,27 @@ class CompositeProperty(DescriptorProperty):
state.dict.pop(self.key, None)
- event.listen(self.parent, 'after_insert',
- insert_update_handler, raw=True)
- event.listen(self.parent, 'after_update',
- insert_update_handler, raw=True)
- event.listen(self.parent, 'load',
- load_handler, raw=True, propagate=True)
- event.listen(self.parent, 'refresh',
- refresh_handler, raw=True, propagate=True)
- event.listen(self.parent, 'expire',
- expire_handler, raw=True, propagate=True)
+ event.listen(
+ self.parent, "after_insert", insert_update_handler, raw=True
+ )
+ event.listen(
+ self.parent, "after_update", insert_update_handler, raw=True
+ )
+ event.listen(
+ self.parent, "load", load_handler, raw=True, propagate=True
+ )
+ event.listen(
+ self.parent, "refresh", refresh_handler, raw=True, propagate=True
+ )
+ event.listen(
+ self.parent, "expire", expire_handler, raw=True, propagate=True
+ )
# TODO: need a deserialize hook here
@util.memoized_property
def _attribute_keys(self):
- return [
- prop.key for prop in self.props
- ]
+ return [prop.key for prop in self.props]
def get_history(self, state, dict_, passive=attributes.PASSIVE_OFF):
"""Provided for userland code that uses attributes.get_history()."""
@@ -363,12 +357,10 @@ class CompositeProperty(DescriptorProperty):
return attributes.History(
[self.composite_class(*added)],
(),
- [self.composite_class(*deleted)]
+ [self.composite_class(*deleted)],
)
else:
- return attributes.History(
- (), [self.composite_class(*added)], ()
- )
+ return attributes.History((), [self.composite_class(*added)], ())
def _comparator_factory(self, mapper):
return self.comparator_factory(self, mapper)
@@ -377,12 +369,15 @@ class CompositeProperty(DescriptorProperty):
def __init__(self, property, expr):
self.property = property
super(CompositeProperty.CompositeBundle, self).__init__(
- property.key, *expr)
+ property.key, *expr
+ )
def create_row_processor(self, query, procs, labels):
def proc(row):
return self.property.composite_class(
- *[proc(row) for proc in procs])
+ *[proc(row) for proc in procs]
+ )
+
return proc
class Comparator(PropComparator):
@@ -412,11 +407,13 @@ class CompositeProperty(DescriptorProperty):
def __clause_element__(self):
return expression.ClauseList(
- group=False, *self._comparable_elements)
+ group=False, *self._comparable_elements
+ )
def _query_clause_element(self):
return CompositeProperty.CompositeBundle(
- self.prop, self.__clause_element__())
+ self.prop, self.__clause_element__()
+ )
def _bulk_update_tuples(self, value):
if value is None:
@@ -425,22 +422,18 @@ class CompositeProperty(DescriptorProperty):
values = value.__composite_values__()
else:
raise sa_exc.ArgumentError(
- "Can't UPDATE composite attribute %s to %r" %
- (self.prop, value))
+ "Can't UPDATE composite attribute %s to %r"
+ % (self.prop, value)
+ )
- return zip(
- self._comparable_elements,
- values
- )
+ return zip(self._comparable_elements, values)
@util.memoized_property
def _comparable_elements(self):
if self._adapt_to_entity:
return [
- getattr(
- self._adapt_to_entity.entity,
- prop.key
- ) for prop in self.prop._comparable_elements
+ getattr(self._adapt_to_entity.entity, prop.key)
+ for prop in self.prop._comparable_elements
]
else:
return self.prop._comparable_elements
@@ -451,8 +444,7 @@ class CompositeProperty(DescriptorProperty):
else:
values = other.__composite_values__()
comparisons = [
- a == b
- for a, b in zip(self.prop._comparable_elements, values)
+ a == b for a, b in zip(self.prop._comparable_elements, values)
]
if self._adapt_to_entity:
comparisons = [self.adapter(x) for x in comparisons]
@@ -495,14 +487,16 @@ class ConcreteInheritedProperty(DescriptorProperty):
def __init__(self):
super(ConcreteInheritedProperty, self).__init__()
+
def warn():
- raise AttributeError("Concrete %s does not implement "
- "attribute %r at the instance level. Add "
- "this property explicitly to %s." %
- (self.parent, self.key, self.parent))
+ raise AttributeError(
+ "Concrete %s does not implement "
+ "attribute %r at the instance level. Add "
+ "this property explicitly to %s."
+ % (self.parent, self.key, self.parent)
+ )
class NoninheritedConcreteProp(object):
-
def __set__(s, obj, value):
warn()
@@ -513,15 +507,21 @@ class ConcreteInheritedProperty(DescriptorProperty):
if obj is None:
return self.descriptor
warn()
+
self.descriptor = NoninheritedConcreteProp()
@util.langhelpers.dependency_for("sqlalchemy.orm.properties", add_to_all=True)
class SynonymProperty(DescriptorProperty):
-
- def __init__(self, name, map_column=None,
- descriptor=None, comparator_factory=None,
- doc=None, info=None):
+ def __init__(
+ self,
+ name,
+ map_column=None,
+ descriptor=None,
+ comparator_factory=None,
+ doc=None,
+ info=None,
+ ):
"""Denote an attribute name as a synonym to a mapped property,
in that the attribute will mirror the value and expression behavior
of another attribute.
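
A minimal usage sketch (model and column names hypothetical):

from sqlalchemy import Column, Integer, String
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import synonym

Base = declarative_base()

class Job(Base):
    __tablename__ = "job"
    id = Column(Integer, primary_key=True)
    job_status = Column(String(50))
    # 'status' mirrors 'job_status' at the instance and query level
    status = synonym("job_status")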
@@ -639,15 +639,13 @@ class SynonymProperty(DescriptorProperty):
@util.memoized_property
def _proxied_property(self):
attr = getattr(self.parent.class_, self.name)
- if not hasattr(attr, 'property') or not \
- isinstance(attr.property, MapperProperty):
+ if not hasattr(attr, "property") or not isinstance(
+ attr.property, MapperProperty
+ ):
raise sa_exc.InvalidRequestError(
"""synonym() attribute "%s.%s" only supports """
- """ORM mapped attributes, got %r""" % (
- self.parent.class_.__name__,
- self.name,
- attr
- )
+ """ORM mapped attributes, got %r"""
+ % (self.parent.class_.__name__, self.name, attr)
)
return attr.property
@@ -671,23 +669,23 @@ class SynonymProperty(DescriptorProperty):
raise sa_exc.ArgumentError(
"Can't compile synonym '%s': no column on table "
"'%s' named '%s'"
- % (self.name, parent.mapped_table.description, self.key))
- elif parent.mapped_table.c[self.key] in \
- parent._columntoproperty and \
- parent._columntoproperty[
- parent.mapped_table.c[self.key]
- ].key == self.name:
+ % (self.name, parent.mapped_table.description, self.key)
+ )
+ elif (
+ parent.mapped_table.c[self.key] in parent._columntoproperty
+ and parent._columntoproperty[
+ parent.mapped_table.c[self.key]
+ ].key
+ == self.name
+ ):
raise sa_exc.ArgumentError(
"Can't call map_column=True for synonym %r=%r, "
"a ColumnProperty already exists keyed to the name "
- "%r for column %r" %
- (self.key, self.name, self.name, self.key)
+ "%r for column %r"
+ % (self.key, self.name, self.name, self.key)
)
p = properties.ColumnProperty(parent.mapped_table.c[self.key])
- parent._configure_property(
- self.name, p,
- init=init,
- setparent=True)
+ parent._configure_property(self.name, p, init=init, setparent=True)
p._mapped_by_synonym = self.key
self.parent = parent
@@ -698,7 +696,8 @@ class ComparableProperty(DescriptorProperty):
"""Instruments a Python property for use in query expressions."""
def __init__(
- self, comparator_factory, descriptor=None, doc=None, info=None):
+ self, comparator_factory, descriptor=None, doc=None, info=None
+ ):
"""Provides a method of applying a :class:`.PropComparator`
to any Python descriptor attribute.
diff --git a/lib/sqlalchemy/orm/dynamic.py b/lib/sqlalchemy/orm/dynamic.py
index 087e7dcc6..e5c6b80b6 100644
--- a/lib/sqlalchemy/orm/dynamic.py
+++ b/lib/sqlalchemy/orm/dynamic.py
@@ -15,8 +15,13 @@ basic add/delete mutation.
from .. import log, util, exc
from ..sql import operators
from . import (
- attributes, object_session, util as orm_util, strategies,
- object_mapper, exc as orm_exc, properties
+ attributes,
+ object_session,
+ util as orm_util,
+ strategies,
+ object_mapper,
+ exc as orm_exc,
+ properties,
)
from .query import Query
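
A sketch of what these imports serve (names hypothetical): lazy="dynamic" installs DynaLoader/DynamicAttributeImpl so the collection attribute hands back a Query instead of loading rows into memory.

from sqlalchemy import Column, ForeignKey, Integer
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship

Base = declarative_base()

class User(Base):
    __tablename__ = "user"
    id = Column(Integer, primary_key=True)
    # the attribute acts as a query: filtering and counting run in SQL,
    # e.g. user.posts.filter(Post.id > 10).count()
    posts = relationship("Post", lazy="dynamic")

class Post(Base):
    __tablename__ = "post"
    id = Column(Integer, primary_key=True)
    user_id = Column(Integer, ForeignKey("user.id"))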
@@ -30,7 +35,8 @@ class DynaLoader(strategies.AbstractRelationshipLoader):
raise exc.InvalidRequestError(
"On relationship %s, 'dynamic' loaders cannot be used with "
"many-to-one/one-to-one relationships and/or "
- "uselist=False." % self.parent_property)
+ "uselist=False." % self.parent_property
+ )
strategies._register_attribute(
self.parent_property,
mapper,
@@ -49,11 +55,20 @@ class DynamicAttributeImpl(attributes.AttributeImpl):
collection = False
dynamic = True
- def __init__(self, class_, key, typecallable,
- dispatch,
- target_mapper, order_by, query_class=None, **kw):
- super(DynamicAttributeImpl, self).\
- __init__(class_, key, typecallable, dispatch, **kw)
+ def __init__(
+ self,
+ class_,
+ key,
+ typecallable,
+ dispatch,
+ target_mapper,
+ order_by,
+ query_class=None,
+ **kw
+ ):
+ super(DynamicAttributeImpl, self).__init__(
+ class_, key, typecallable, dispatch, **kw
+ )
self.target_mapper = target_mapper
self.order_by = order_by
if not query_class:
@@ -66,15 +81,20 @@ class DynamicAttributeImpl(attributes.AttributeImpl):
def get(self, state, dict_, passive=attributes.PASSIVE_OFF):
if not passive & attributes.SQL_OK:
return self._get_collection_history(
- state, attributes.PASSIVE_NO_INITIALIZE).added_items
+ state, attributes.PASSIVE_NO_INITIALIZE
+ ).added_items
else:
return self.query_class(self, state)
- def get_collection(self, state, dict_, user_data=None,
- passive=attributes.PASSIVE_NO_INITIALIZE):
+ def get_collection(
+ self,
+ state,
+ dict_,
+ user_data=None,
+ passive=attributes.PASSIVE_NO_INITIALIZE,
+ ):
if not passive & attributes.SQL_OK:
- return self._get_collection_history(state,
- passive).added_items
+ return self._get_collection_history(state, passive).added_items
else:
history = self._get_collection_history(state, passive)
return history.added_plus_unchanged
@@ -87,8 +107,9 @@ class DynamicAttributeImpl(attributes.AttributeImpl):
def _remove_token(self):
return attributes.Event(self, attributes.OP_REMOVE)
- def fire_append_event(self, state, dict_, value, initiator,
- collection_history=None):
+ def fire_append_event(
+ self, state, dict_, value, initiator, collection_history=None
+ ):
if collection_history is None:
collection_history = self._modified_event(state, dict_)
@@ -100,8 +121,9 @@ class DynamicAttributeImpl(attributes.AttributeImpl):
if self.trackparent and value is not None:
self.sethasparent(attributes.instance_state(value), state, True)
- def fire_remove_event(self, state, dict_, value, initiator,
- collection_history=None):
+ def fire_remove_event(
+ self, state, dict_, value, initiator, collection_history=None
+ ):
if collection_history is None:
collection_history = self._modified_event(state, dict_)
@@ -118,18 +140,24 @@ class DynamicAttributeImpl(attributes.AttributeImpl):
if self.key not in state.committed_state:
state.committed_state[self.key] = CollectionHistory(self, state)
- state._modified_event(dict_,
- self,
- attributes.NEVER_SET)
+ state._modified_event(dict_, self, attributes.NEVER_SET)
# this is a hack to allow the fixtures.ComparableEntity fixture
# to work
dict_[self.key] = True
return state.committed_state[self.key]
- def set(self, state, dict_, value, initiator=None,
- passive=attributes.PASSIVE_OFF,
- check_old=None, pop=False, _adapt=True):
+ def set(
+ self,
+ state,
+ dict_,
+ value,
+ initiator=None,
+ passive=attributes.PASSIVE_OFF,
+ check_old=None,
+ pop=False,
+ _adapt=True,
+ ):
if initiator and initiator.parent_token is self.parent_token:
return
@@ -146,7 +174,8 @@ class DynamicAttributeImpl(attributes.AttributeImpl):
old_collection = collection_history.added_items
else:
old_collection = old_collection.union(
- collection_history.added_items)
+ collection_history.added_items
+ )
idset = util.IdentitySet
constants = old_collection.intersection(new_values)
@@ -155,33 +184,40 @@ class DynamicAttributeImpl(attributes.AttributeImpl):
for member in new_values:
if member in additions:
- self.fire_append_event(state, dict_, member, None,
- collection_history=collection_history)
+ self.fire_append_event(
+ state,
+ dict_,
+ member,
+ None,
+ collection_history=collection_history,
+ )
for member in removals:
- self.fire_remove_event(state, dict_, member, None,
- collection_history=collection_history)
+ self.fire_remove_event(
+ state,
+ dict_,
+ member,
+ None,
+ collection_history=collection_history,
+ )
def delete(self, *args, **kwargs):
raise NotImplementedError()
def set_committed_value(self, state, dict_, value):
- raise NotImplementedError("Dynamic attributes don't support "
- "collection population.")
+ raise NotImplementedError(
+ "Dynamic attributes don't support " "collection population."
+ )
def get_history(self, state, dict_, passive=attributes.PASSIVE_OFF):
c = self._get_collection_history(state, passive)
return c.as_history()
- def get_all_pending(self, state, dict_,
- passive=attributes.PASSIVE_NO_INITIALIZE):
- c = self._get_collection_history(
- state, passive)
- return [
- (attributes.instance_state(x), x)
- for x in
- c.all_items
- ]
+ def get_all_pending(
+ self, state, dict_, passive=attributes.PASSIVE_NO_INITIALIZE
+ ):
+ c = self._get_collection_history(state, passive)
+ return [(attributes.instance_state(x), x) for x in c.all_items]
def _get_collection_history(self, state, passive=attributes.PASSIVE_OFF):
if self.key in state.committed_state:
@@ -194,18 +230,21 @@ class DynamicAttributeImpl(attributes.AttributeImpl):
else:
return c
- def append(self, state, dict_, value, initiator,
- passive=attributes.PASSIVE_OFF):
+ def append(
+ self, state, dict_, value, initiator, passive=attributes.PASSIVE_OFF
+ ):
if initiator is not self:
self.fire_append_event(state, dict_, value, initiator)
- def remove(self, state, dict_, value, initiator,
- passive=attributes.PASSIVE_OFF):
+ def remove(
+ self, state, dict_, value, initiator, passive=attributes.PASSIVE_OFF
+ ):
if initiator is not self:
self.fire_remove_event(state, dict_, value, initiator)
- def pop(self, state, dict_, value, initiator,
- passive=attributes.PASSIVE_OFF):
+ def pop(
+ self, state, dict_, value, initiator, passive=attributes.PASSIVE_OFF
+ ):
self.remove(state, dict_, value, initiator, passive=passive)
@@ -229,30 +268,36 @@ class AppenderMixin(object):
# doesn't fail, and secondary is then in _from_obj[1].
self._from_obj = (prop.mapper.selectable, prop.secondary)
- self._criterion = prop._with_parent(
- instance,
- alias_secondary=False)
+ self._criterion = prop._with_parent(instance, alias_secondary=False)
if self.attr.order_by:
self._order_by = self.attr.order_by
def session(self):
sess = object_session(self.instance)
- if sess is not None and self.autoflush and sess.autoflush \
- and self.instance in sess:
+ if (
+ sess is not None
+ and self.autoflush
+ and sess.autoflush
+ and self.instance in sess
+ ):
sess.flush()
if not orm_util.has_identity(self.instance):
return None
else:
return sess
+
session = property(session, lambda s, x: None)
def __iter__(self):
sess = self.session
if sess is None:
- return iter(self.attr._get_collection_history(
- attributes.instance_state(self.instance),
- attributes.PASSIVE_NO_INITIALIZE).added_items)
+ return iter(
+ self.attr._get_collection_history(
+ attributes.instance_state(self.instance),
+ attributes.PASSIVE_NO_INITIALIZE,
+ ).added_items
+ )
else:
return iter(self._clone(sess))
@@ -261,16 +306,20 @@ class AppenderMixin(object):
if sess is None:
return self.attr._get_collection_history(
attributes.instance_state(self.instance),
- attributes.PASSIVE_NO_INITIALIZE).indexed(index)
+ attributes.PASSIVE_NO_INITIALIZE,
+ ).indexed(index)
else:
return self._clone(sess).__getitem__(index)
def count(self):
sess = self.session
if sess is None:
- return len(self.attr._get_collection_history(
- attributes.instance_state(self.instance),
- attributes.PASSIVE_NO_INITIALIZE).added_items)
+ return len(
+ self.attr._get_collection_history(
+ attributes.instance_state(self.instance),
+ attributes.PASSIVE_NO_INITIALIZE,
+ ).added_items
+ )
else:
return self._clone(sess).count()
@@ -285,8 +334,9 @@ class AppenderMixin(object):
raise orm_exc.DetachedInstanceError(
"Parent instance %s is not bound to a Session, and no "
"contextual session is established; lazy load operation "
- "of attribute '%s' cannot proceed" % (
- orm_util.instance_str(instance), self.attr.key))
+ "of attribute '%s' cannot proceed"
+ % (orm_util.instance_str(instance), self.attr.key)
+ )
if self.query_class:
query = self.query_class(self.attr.target_mapper, session=sess)
@@ -303,17 +353,26 @@ class AppenderMixin(object):
for item in iterator:
self.attr.append(
attributes.instance_state(self.instance),
- attributes.instance_dict(self.instance), item, None)
+ attributes.instance_dict(self.instance),
+ item,
+ None,
+ )
def append(self, item):
self.attr.append(
attributes.instance_state(self.instance),
- attributes.instance_dict(self.instance), item, None)
+ attributes.instance_dict(self.instance),
+ item,
+ None,
+ )
def remove(self, item):
self.attr.remove(
attributes.instance_state(self.instance),
- attributes.instance_dict(self.instance), item, None)
+ attributes.instance_dict(self.instance),
+ item,
+ None,
+ )
class AppenderQuery(AppenderMixin, Query):
@@ -322,8 +381,8 @@ class AppenderQuery(AppenderMixin, Query):
def mixin_user_query(cls):
"""Return a new class with AppenderQuery functionality layered over."""
- name = 'Appender' + cls.__name__
- return type(name, (AppenderMixin, cls), {'query_class': cls})
+ name = "Appender" + cls.__name__
+ return type(name, (AppenderMixin, cls), {"query_class": cls})
class CollectionHistory(object):
@@ -348,8 +407,11 @@ class CollectionHistory(object):
@property
def all_items(self):
- return list(self.added_items.union(
- self.unchanged_items).union(self.deleted_items))
+ return list(
+ self.added_items.union(self.unchanged_items).union(
+ self.deleted_items
+ )
+ )
def as_history(self):
if self._reconcile_collection:
@@ -357,14 +419,12 @@ class CollectionHistory(object):
deleted = self.deleted_items.intersection(self.unchanged_items)
unchanged = self.unchanged_items.difference(deleted)
else:
- added, unchanged, deleted = self.added_items,\
- self.unchanged_items,\
- self.deleted_items
- return attributes.History(
- list(added),
- list(unchanged),
- list(deleted),
- )
+ added, unchanged, deleted = (
+ self.added_items,
+ self.unchanged_items,
+ self.deleted_items,
+ )
+ return attributes.History(list(added), list(unchanged), list(deleted))
def indexed(self, index):
return list(self.added_items)[index]
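
Editor's note: to see the classes above in context, a hedged sketch of a dynamic relationship (Parent/Child are invented names). Attribute access yields an AppenderQuery rather than a loaded list, and uselist=False would trip the InvalidRequestError raised in DynaLoader:

    from sqlalchemy import Column, ForeignKey, Integer, create_engine
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import Session, relationship

    Base = declarative_base()

    class Parent(Base):
        __tablename__ = "parent"
        id = Column(Integer, primary_key=True)
        # lazy="dynamic" installs DynamicAttributeImpl on the attribute.
        children = relationship("Child", lazy="dynamic")

    class Child(Base):
        __tablename__ = "child"
        id = Column(Integer, primary_key=True)
        parent_id = Column(Integer, ForeignKey("parent.id"))

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)
    session = Session(engine)
    session.add(Parent(children=[Child(), Child()]))
    session.commit()

    parent = session.query(Parent).first()
    print(parent.children.count())                # emits SELECT COUNT(*)
    print(parent.children.filter_by(id=2).all())  # further filtering in SQL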
diff --git a/lib/sqlalchemy/orm/evaluator.py b/lib/sqlalchemy/orm/evaluator.py
index 4abf08ab1..ac031d84f 100644
--- a/lib/sqlalchemy/orm/evaluator.py
+++ b/lib/sqlalchemy/orm/evaluator.py
@@ -14,17 +14,40 @@ from .. import util
class UnevaluatableError(Exception):
pass
-_straight_ops = set(getattr(operators, op)
- for op in ('add', 'mul', 'sub',
- 'div',
- 'mod', 'truediv',
- 'lt', 'le', 'ne', 'gt', 'ge', 'eq'))
-
-_notimplemented_ops = set(getattr(operators, op)
- for op in ('like_op', 'notlike_op', 'ilike_op',
- 'notilike_op', 'between_op', 'in_op',
- 'notin_op', 'endswith_op', 'concat_op'))
+_straight_ops = set(
+ getattr(operators, op)
+ for op in (
+ "add",
+ "mul",
+ "sub",
+ "div",
+ "mod",
+ "truediv",
+ "lt",
+ "le",
+ "ne",
+ "gt",
+ "ge",
+ "eq",
+ )
+)
+
+
+_notimplemented_ops = set(
+ getattr(operators, op)
+ for op in (
+ "like_op",
+ "notlike_op",
+ "ilike_op",
+ "notilike_op",
+ "between_op",
+ "in_op",
+ "notin_op",
+ "endswith_op",
+ "concat_op",
+ )
+)
class EvaluatorCompiler(object):
@@ -35,7 +58,8 @@ class EvaluatorCompiler(object):
meth = getattr(self, "visit_%s" % clause.__visit_name__, None)
if not meth:
raise UnevaluatableError(
- "Cannot evaluate %s" % type(clause).__name__)
+ "Cannot evaluate %s" % type(clause).__name__
+ )
return meth(clause)
def visit_grouping(self, clause):
@@ -51,28 +75,30 @@ class EvaluatorCompiler(object):
return lambda obj: True
def visit_column(self, clause):
- if 'parentmapper' in clause._annotations:
- parentmapper = clause._annotations['parentmapper']
+ if "parentmapper" in clause._annotations:
+ parentmapper = clause._annotations["parentmapper"]
if self.target_cls and not issubclass(
- self.target_cls, parentmapper.class_):
+ self.target_cls, parentmapper.class_
+ ):
raise UnevaluatableError(
- "Can't evaluate criteria against alternate class %s" %
- parentmapper.class_
+ "Can't evaluate criteria against alternate class %s"
+ % parentmapper.class_
)
key = parentmapper._columntoproperty[clause].key
else:
key = clause.key
- if self.target_cls and \
- key in inspect(self.target_cls).column_attrs:
+ if (
+ self.target_cls
+ and key in inspect(self.target_cls).column_attrs
+ ):
util.warn(
"Evaluating non-mapped column expression '%s' onto "
"ORM instances; this is a deprecated use case. Please "
"make use of the actual mapped columns in ORM-evaluated "
- "UPDATE / DELETE expressions." % clause)
- else:
- raise UnevaluatableError(
- "Cannot evaluate column: %s" % clause
+ "UPDATE / DELETE expressions." % clause
)
+ else:
+ raise UnevaluatableError("Cannot evaluate column: %s" % clause)
get_corresponding_attr = operator.attrgetter(key)
return lambda obj: get_corresponding_attr(obj)
@@ -80,6 +106,7 @@ class EvaluatorCompiler(object):
def visit_clauselist(self, clause):
evaluators = list(map(self.process, clause.clauses))
if clause.operator is operators.or_:
+
def evaluate(obj):
has_null = False
for sub_evaluate in evaluators:
@@ -90,7 +117,9 @@ class EvaluatorCompiler(object):
if has_null:
return None
return False
+
elif clause.operator is operators.and_:
+
def evaluate(obj):
for sub_evaluate in evaluators:
value = sub_evaluate(obj)
@@ -99,48 +128,60 @@ class EvaluatorCompiler(object):
return None
return False
return True
+
else:
raise UnevaluatableError(
- "Cannot evaluate clauselist with operator %s" %
- clause.operator)
+ "Cannot evaluate clauselist with operator %s" % clause.operator
+ )
return evaluate
def visit_binary(self, clause):
- eval_left, eval_right = list(map(self.process,
- [clause.left, clause.right]))
+ eval_left, eval_right = list(
+ map(self.process, [clause.left, clause.right])
+ )
operator = clause.operator
if operator is operators.is_:
+
def evaluate(obj):
return eval_left(obj) == eval_right(obj)
+
elif operator is operators.isnot:
+
def evaluate(obj):
return eval_left(obj) != eval_right(obj)
+
elif operator in _straight_ops:
+
def evaluate(obj):
left_val = eval_left(obj)
right_val = eval_right(obj)
if left_val is None or right_val is None:
return None
return operator(eval_left(obj), eval_right(obj))
+
else:
raise UnevaluatableError(
- "Cannot evaluate %s with operator %s" %
- (type(clause).__name__, clause.operator))
+ "Cannot evaluate %s with operator %s"
+ % (type(clause).__name__, clause.operator)
+ )
return evaluate
def visit_unary(self, clause):
eval_inner = self.process(clause.element)
if clause.operator is operators.inv:
+
def evaluate(obj):
value = eval_inner(obj)
if value is None:
return None
return not value
+
return evaluate
raise UnevaluatableError(
- "Cannot evaluate %s with operator %s" %
- (type(clause).__name__, clause.operator))
+ "Cannot evaluate %s with operator %s"
+ % (type(clause).__name__, clause.operator)
+ )
def visit_bindparam(self, clause):
if clause.callable:
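
Editor's note: a sketch of what EvaluatorCompiler produces (the Account model is hypothetical); it shows the SQL-style None propagation implemented by the _straight_ops branch of visit_binary, and that operators listed in _notimplemented_ops refuse to compile:

    from sqlalchemy import Column, Integer, String
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm.evaluator import EvaluatorCompiler, UnevaluatableError

    Base = declarative_base()

    class Account(Base):
        __tablename__ = "account"
        id = Column(Integer, primary_key=True)
        balance = Column(Integer)
        name = Column(String(50))

    # compile the criterion into a plain Python callable
    fn = EvaluatorCompiler(Account).process(Account.balance > 100)
    print(fn(Account(balance=150)))   # True
    print(fn(Account(balance=None)))  # None, mirroring SQL NULL semantics

    try:
        EvaluatorCompiler(Account).process(Account.name.like("a%"))
    except UnevaluatableError as err:
        print(err)  # like_op appears in _notimplemented_ops above

This in-memory evaluation is what backs Query.update()/delete() with synchronize_session="evaluate"; criteria the compiler rejects require the caller to pass "fetch" or False instead.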
diff --git a/lib/sqlalchemy/orm/events.py b/lib/sqlalchemy/orm/events.py
index c414f548e..c2a2d15ee 100644
--- a/lib/sqlalchemy/orm/events.py
+++ b/lib/sqlalchemy/orm/events.py
@@ -20,6 +20,7 @@ from .attributes import QueryableAttribute
from .query import Query
from sqlalchemy.util.compat import inspect_getargspec
+
class InstrumentationEvents(event.Events):
"""Events related to class instrumentation events.
@@ -61,9 +62,11 @@ class InstrumentationEvents(event.Events):
@classmethod
def _listen(cls, event_key, propagate=True, **kw):
- target, identifier, fn = \
- event_key.dispatch_target, event_key.identifier, \
- event_key._listen_fn
+ target, identifier, fn = (
+ event_key.dispatch_target,
+ event_key.identifier,
+ event_key._listen_fn,
+ )
def listen(target_cls, *arg):
listen_cls = target()
@@ -74,16 +77,20 @@ class InstrumentationEvents(event.Events):
def remove(ref):
key = event.registry._EventKey(
- None, identifier, listen,
- instrumentation._instrumentation_factory)
- getattr(instrumentation._instrumentation_factory.dispatch,
- identifier).remove(key)
+ None,
+ identifier,
+ listen,
+ instrumentation._instrumentation_factory,
+ )
+ getattr(
+ instrumentation._instrumentation_factory.dispatch, identifier
+ ).remove(key)
target = weakref.ref(target.class_, remove)
- event_key.\
- with_dispatch_target(instrumentation._instrumentation_factory).\
- with_wrapper(listen).base_listen(**kw)
+ event_key.with_dispatch_target(
+ instrumentation._instrumentation_factory
+ ).with_wrapper(listen).base_listen(**kw)
@classmethod
def _clear(cls):
@@ -193,21 +200,24 @@ class InstanceEvents(event.Events):
@classmethod
def _listen(cls, event_key, raw=False, propagate=False, **kw):
- target, identifier, fn = \
- event_key.dispatch_target, event_key.identifier, \
- event_key._listen_fn
+ target, identifier, fn = (
+ event_key.dispatch_target,
+ event_key.identifier,
+ event_key._listen_fn,
+ )
if not raw:
+
def wrap(state, *arg, **kw):
return fn(state.obj(), *arg, **kw)
+
event_key = event_key.with_wrapper(wrap)
event_key.base_listen(propagate=propagate, **kw)
if propagate:
for mgr in target.subclass_managers(True):
- event_key.with_dispatch_target(mgr).base_listen(
- propagate=True)
+ event_key.with_dispatch_target(mgr).base_listen(propagate=True)
@classmethod
def _clear(cls):
@@ -438,10 +448,13 @@ class _EventsHold(event.RefCollection):
@classmethod
def _listen(
- cls, event_key, raw=False, propagate=False,
- retval=False, **kw):
- target, identifier, fn = \
- event_key.dispatch_target, event_key.identifier, event_key.fn
+ cls, event_key, raw=False, propagate=False, retval=False, **kw
+ ):
+ target, identifier, fn = (
+ event_key.dispatch_target,
+ event_key.identifier,
+ event_key.fn,
+ )
if target.class_ in target.all_holds:
collection = target.all_holds[target.class_]
@@ -460,12 +473,16 @@ class _EventsHold(event.RefCollection):
if subject is not None:
# we are already going through __subclasses__()
# so leave generic propagate flag False
- event_key.with_dispatch_target(subject).\
- listen(raw=raw, propagate=False, retval=retval, **kw)
+ event_key.with_dispatch_target(subject).listen(
+ raw=raw, propagate=False, retval=retval, **kw
+ )
def remove(self, event_key):
- target, identifier, fn = \
- event_key.dispatch_target, event_key.identifier, event_key.fn
+ target, identifier, fn = (
+ event_key.dispatch_target,
+ event_key.identifier,
+ event_key.fn,
+ )
if isinstance(target, _EventsHold):
collection = target.all_holds[target.class_]
@@ -483,8 +500,9 @@ class _EventsHold(event.RefCollection):
# populate(), we rely upon _EventsHold for all event
# assignment, instead of using the generic propagate
# flag.
- event_key.with_dispatch_target(subject).\
- listen(raw=raw, propagate=False, retval=retval)
+ event_key.with_dispatch_target(subject).listen(
+ raw=raw, propagate=False, retval=retval
+ )
class _InstanceEventsHold(_EventsHold):
@@ -594,24 +612,31 @@ class MapperEvents(event.Events):
@classmethod
def _listen(
- cls, event_key, raw=False, retval=False, propagate=False, **kw):
- target, identifier, fn = \
- event_key.dispatch_target, event_key.identifier, \
- event_key._listen_fn
-
- if identifier in ("before_configured", "after_configured") and \
- target is not mapperlib.Mapper:
+ cls, event_key, raw=False, retval=False, propagate=False, **kw
+ ):
+ target, identifier, fn = (
+ event_key.dispatch_target,
+ event_key.identifier,
+ event_key._listen_fn,
+ )
+
+ if (
+ identifier in ("before_configured", "after_configured")
+ and target is not mapperlib.Mapper
+ ):
util.warn(
"'before_configured' and 'after_configured' ORM events "
"only invoke with the mapper() function or Mapper class "
- "as the target.")
+ "as the target."
+ )
if not raw or not retval:
if not raw:
meth = getattr(cls, identifier)
try:
- target_index = \
- inspect_getargspec(meth)[0].index('target') - 1
+ target_index = (
+ inspect_getargspec(meth)[0].index("target") - 1
+ )
except ValueError:
target_index = None
@@ -624,12 +649,14 @@ class MapperEvents(event.Events):
return interfaces.EXT_CONTINUE
else:
return fn(*arg, **kw)
+
event_key = event_key.with_wrapper(wrap)
if propagate:
for mapper in target.self_and_descendants:
event_key.with_dispatch_target(mapper).base_listen(
- propagate=True, **kw)
+ propagate=True, **kw
+ )
else:
event_key.base_listen(**kw)
@@ -1219,15 +1246,14 @@ class SessionEvents(event.Events):
if isinstance(target, scoped_session):
target = target.session_factory
- if not isinstance(target, sessionmaker) and \
- (
- not isinstance(target, type) or
- not issubclass(target, Session)
+ if not isinstance(target, sessionmaker) and (
+ not isinstance(target, type) or not issubclass(target, Session)
):
raise exc.ArgumentError(
"Session event listen on a scoped_session "
"requires that its creation callable "
- "is associated with the Session class.")
+ "is associated with the Session class."
+ )
if isinstance(target, sessionmaker):
return target.class_
@@ -1561,13 +1587,16 @@ class SessionEvents(event.Events):
"""
- @event._legacy_signature("0.9",
- ["session", "query", "query_context", "result"],
- lambda update_context: (
- update_context.session,
- update_context.query,
- update_context.context,
- update_context.result))
+ @event._legacy_signature(
+ "0.9",
+ ["session", "query", "query_context", "result"],
+ lambda update_context: (
+ update_context.session,
+ update_context.query,
+ update_context.context,
+ update_context.result,
+ ),
+ )
def after_bulk_update(self, update_context):
"""Execute after a bulk update operation to the session.
@@ -1587,13 +1616,16 @@ class SessionEvents(event.Events):
"""
- @event._legacy_signature("0.9",
- ["session", "query", "query_context", "result"],
- lambda delete_context: (
- delete_context.session,
- delete_context.query,
- delete_context.context,
- delete_context.result))
+ @event._legacy_signature(
+ "0.9",
+ ["session", "query", "query_context", "result"],
+ lambda delete_context: (
+ delete_context.session,
+ delete_context.query,
+ delete_context.context,
+ delete_context.result,
+ ),
+ )
def after_bulk_delete(self, delete_context):
"""Execute after a bulk delete operation to the session.
@@ -1927,18 +1959,26 @@ class AttributeEvents(event.Events):
return target
@classmethod
- def _listen(cls, event_key, active_history=False,
- raw=False, retval=False,
- propagate=False):
-
- target, identifier, fn = \
- event_key.dispatch_target, event_key.identifier, \
- event_key._listen_fn
+ def _listen(
+ cls,
+ event_key,
+ active_history=False,
+ raw=False,
+ retval=False,
+ propagate=False,
+ ):
+
+ target, identifier, fn = (
+ event_key.dispatch_target,
+ event_key.identifier,
+ event_key._listen_fn,
+ )
if active_history:
target.dispatch._active_history = True
if not raw or not retval:
+
def wrap(target, *arg):
if not raw:
target = target.obj()
@@ -1951,6 +1991,7 @@ class AttributeEvents(event.Events):
return value
else:
return fn(target, *arg)
+
event_key = event_key.with_wrapper(wrap)
event_key.base_listen(propagate=propagate)
@@ -1959,8 +2000,9 @@ class AttributeEvents(event.Events):
manager = instrumentation.manager_of_class(target.class_)
for mgr in manager.subclass_managers(True):
- event_key.with_dispatch_target(
- mgr[target.key]).base_listen(propagate=True)
+ event_key.with_dispatch_target(mgr[target.key]).base_listen(
+ propagate=True
+ )
def append(self, target, value, initiator):
"""Receive a collection append event.
@@ -2315,11 +2357,11 @@ class QueryEvents(event.Events):
"""
@classmethod
- def _listen(
- cls, event_key, retval=False, **kw):
+ def _listen(cls, event_key, retval=False, **kw):
fn = event_key._listen_fn
if not retval:
+
def wrap(*arg, **kw):
if not retval:
query = arg[0]
@@ -2327,6 +2369,7 @@ class QueryEvents(event.Events):
return query
else:
return fn(*arg, **kw)
+
event_key = event_key.with_wrapper(wrap)
event_key.base_listen(**kw)
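
Editor's note: the raw/retval wrappers above are what shape user-facing listener signatures; a hedged sketch, reusing the hypothetical Account model from the evaluator sketch:

    from sqlalchemy import event

    # raw=False (the default): wrap() unpacks the InstanceState so the
    # listener receives the plain object.
    @event.listens_for(Account, "load")
    def on_load(target, context):
        print("loaded", target)

    # retval=True skips the value-discarding branch of wrap(); the
    # listener's return value replaces the value being set.
    @event.listens_for(Account.name, "set", retval=True)
    def normalize_name(target, value, oldvalue, initiator):
        return value.strip() if value is not None else value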
diff --git a/lib/sqlalchemy/orm/exc.py b/lib/sqlalchemy/orm/exc.py
index eb4baa08d..f0aa02e99 100644
--- a/lib/sqlalchemy/orm/exc.py
+++ b/lib/sqlalchemy/orm/exc.py
@@ -38,6 +38,7 @@ class StaleDataError(sa_exc.SQLAlchemyError):
"""
+
ConcurrentModificationError = StaleDataError
@@ -72,16 +73,19 @@ class UnmappedInstanceError(UnmappedError):
try:
base.class_mapper(type(obj))
name = _safe_cls_name(type(obj))
- msg = ("Class %r is mapped, but this instance lacks "
- "instrumentation. This occurs when the instance "
- "is created before sqlalchemy.orm.mapper(%s) "
- "was called." % (name, name))
+ msg = (
+ "Class %r is mapped, but this instance lacks "
+ "instrumentation. This occurs when the instance "
+ "is created before sqlalchemy.orm.mapper(%s) "
+ "was called." % (name, name)
+ )
except UnmappedClassError:
msg = _default_unmapped(type(obj))
if isinstance(obj, type):
msg += (
- '; was a class (%s) supplied where an instance was '
- 'required?' % _safe_cls_name(obj))
+ "; was a class (%s) supplied where an instance was "
+ "required?" % _safe_cls_name(obj)
+ )
UnmappedError.__init__(self, msg)
def __reduce__(self):
@@ -119,11 +123,14 @@ class ObjectDeletedError(sa_exc.InvalidRequestError):
object.
"""
+
@util.dependencies("sqlalchemy.orm.base")
def __init__(self, base, state, msg=None):
if not msg:
- msg = "Instance '%s' has been deleted, or its "\
+ msg = (
+ "Instance '%s' has been deleted, or its "
"row is otherwise not present." % base.state_str(state)
+ )
sa_exc.InvalidRequestError.__init__(self, msg)
@@ -145,9 +152,9 @@ class MultipleResultsFound(sa_exc.InvalidRequestError):
def _safe_cls_name(cls):
try:
- cls_name = '.'.join((cls.__module__, cls.__name__))
+ cls_name = ".".join((cls.__module__, cls.__name__))
except AttributeError:
- cls_name = getattr(cls, '__name__', None)
+ cls_name = getattr(cls, "__name__", None)
if cls_name is None:
cls_name = repr(cls)
return cls_name
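
Editor's note: a small sketch of how these exceptions surface (Account is the hypothetical model from the earlier sketches):

    from sqlalchemy.orm import Session
    from sqlalchemy.orm.exc import UnmappedInstanceError

    try:
        Session().add(object())  # plain object, no instrumentation
    except UnmappedInstanceError as err:
        print(err)

    try:
        Session().add(Account)   # the class itself, not an instance
    except UnmappedInstanceError as err:
        print(err)  # ends with the "was a class supplied..." hint above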
diff --git a/lib/sqlalchemy/orm/identity.py b/lib/sqlalchemy/orm/identity.py
index b03bb0a0d..2487cdb23 100644
--- a/lib/sqlalchemy/orm/identity.py
+++ b/lib/sqlalchemy/orm/identity.py
@@ -11,6 +11,7 @@ from .. import util
from .. import exc as sa_exc
from . import util as orm_util
+
class IdentityMap(object):
def __init__(self):
self._dict = {}
@@ -84,7 +85,6 @@ class IdentityMap(object):
class WeakInstanceDict(IdentityMap):
-
def __getitem__(self, key):
state = self._dict[key]
o = state.obj()
@@ -145,8 +145,9 @@ class WeakInstanceDict(IdentityMap):
raise sa_exc.InvalidRequestError(
"Can't attach instance "
"%s; another instance with key %s is already "
- "present in this session." % (
- orm_util.state_str(state), state.key))
+ "present in this session."
+ % (orm_util.state_str(state), state.key)
+ )
else:
return False
self._dict[key] = state
@@ -253,6 +254,7 @@ class StrongInstanceDict(IdentityMap):
"""
if util.py2k:
+
def itervalues(self):
return self._dict.itervalues()
@@ -282,8 +284,9 @@ class StrongInstanceDict(IdentityMap):
def contains_state(self, state):
return (
- state.key in self and
- attributes.instance_state(self[state.key]) is state)
+ state.key in self
+ and attributes.instance_state(self[state.key]) is state
+ )
def replace(self, state):
if state.key in self._dict:
@@ -303,8 +306,9 @@ class StrongInstanceDict(IdentityMap):
raise sa_exc.InvalidRequestError(
"Can't attach instance "
"%s; another instance with key %s is already "
- "present in this session." % (
- orm_util.state_str(state), state.key))
+ "present in this session."
+ % (orm_util.state_str(state), state.key)
+ )
return False
else:
self._dict[state.key] = state.obj()
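
Editor's note: a sketch of the identity map in action, assuming the hypothetical Base/Account model from earlier; the identity key is the (class, primary-key tuple, token) value the map is keyed on:

    from sqlalchemy import create_engine, inspect
    from sqlalchemy.orm import Session

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)
    session = Session(engine)
    session.add(Account(id=1, balance=10))
    session.commit()

    acct = session.query(Account).get(1)
    key = inspect(acct).key                   # e.g. (Account, (1,), None)
    print(key in session.identity_map)        # True
    print(session.identity_map[key] is acct)  # add() above enforces uniqueness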
diff --git a/lib/sqlalchemy/orm/instrumentation.py b/lib/sqlalchemy/orm/instrumentation.py
index d34326e0f..fa29c3233 100644
--- a/lib/sqlalchemy/orm/instrumentation.py
+++ b/lib/sqlalchemy/orm/instrumentation.py
@@ -59,11 +59,15 @@ class ClassManager(dict):
self.local_attrs = {}
self.originals = {}
- self._bases = [mgr for mgr in [
- manager_of_class(base)
- for base in self.class_.__bases__
- if isinstance(base, type)
- ] if mgr is not None]
+ self._bases = [
+ mgr
+ for mgr in [
+ manager_of_class(base)
+ for base in self.class_.__bases__
+ if isinstance(base, type)
+ ]
+ if mgr is not None
+ ]
for base in self._bases:
self.update(base)
@@ -78,12 +82,13 @@ class ClassManager(dict):
self.manage()
self._instrument_init()
- if '__del__' in class_.__dict__:
- util.warn("__del__() method on class %s will "
- "cause unreachable cycles and memory leaks, "
- "as SQLAlchemy instrumentation often creates "
- "reference cycles. Please remove this method." %
- class_)
+ if "__del__" in class_.__dict__:
+ util.warn(
+ "__del__() method on class %s will "
+ "cause unreachable cycles and memory leaks, "
+ "as SQLAlchemy instrumentation often creates "
+ "reference cycles. Please remove this method." % class_
+ )
def __hash__(self):
return id(self)
@@ -93,7 +98,7 @@ class ClassManager(dict):
@property
def is_mapped(self):
- return 'mapper' in self.__dict__
+ return "mapper" in self.__dict__
@_memoized_key_collection
def _all_key_set(self):
@@ -101,14 +106,19 @@ class ClassManager(dict):
@_memoized_key_collection
def _collection_impl_keys(self):
- return frozenset([
- attr.key for attr in self.values() if attr.impl.collection])
+ return frozenset(
+ [attr.key for attr in self.values() if attr.impl.collection]
+ )
@_memoized_key_collection
def _scalar_loader_impls(self):
- return frozenset([
- attr.impl for attr in
- self.values() if attr.impl.accepts_scalar_loader])
+ return frozenset(
+ [
+ attr.impl
+ for attr in self.values()
+ if attr.impl.accepts_scalar_loader
+ ]
+ )
@util.memoized_property
def mapper(self):
@@ -174,11 +184,11 @@ class ClassManager(dict):
# of such, since this adds method overhead.
self.original_init = self.class_.__init__
self.new_init = _generate_init(self.class_, self)
- self.install_member('__init__', self.new_init)
+ self.install_member("__init__", self.new_init)
def _uninstrument_init(self):
if self.new_init:
- self.uninstall_member('__init__')
+ self.uninstall_member("__init__")
self.new_init = None
@util.memoized_property
@@ -239,8 +249,9 @@ class ClassManager(dict):
yield m
def post_configure_attribute(self, key):
- _instrumentation_factory.dispatch.\
- attribute_instrument(self.class_, key, self[key])
+ _instrumentation_factory.dispatch.attribute_instrument(
+ self.class_, key, self[key]
+ )
def uninstrument_attribute(self, key, propagated=False):
if key not in self:
@@ -272,9 +283,10 @@ class ClassManager(dict):
def install_descriptor(self, key, inst):
if key in (self.STATE_ATTR, self.MANAGER_ATTR):
- raise KeyError("%r: requested attribute name conflicts with "
- "instrumentation attribute of the same name." %
- key)
+ raise KeyError(
+ "%r: requested attribute name conflicts with "
+ "instrumentation attribute of the same name." % key
+ )
setattr(self.class_, key, inst)
def uninstall_descriptor(self, key):
@@ -282,9 +294,10 @@ class ClassManager(dict):
def install_member(self, key, implementation):
if key in (self.STATE_ATTR, self.MANAGER_ATTR):
- raise KeyError("%r: requested attribute name conflicts with "
- "instrumentation attribute of the same name." %
- key)
+ raise KeyError(
+ "%r: requested attribute name conflicts with "
+ "instrumentation attribute of the same name." % key
+ )
self.originals.setdefault(key, getattr(self.class_, key, None))
setattr(self.class_, key, implementation)
@@ -299,7 +312,8 @@ class ClassManager(dict):
def initialize_collection(self, key, state, factory):
user_data = factory()
adapter = collections.CollectionAdapter(
- self.get_impl(key), state, user_data)
+ self.get_impl(key), state, user_data
+ )
return adapter, user_data
def is_instrumented(self, key, search=False):
@@ -343,15 +357,15 @@ class ClassManager(dict):
"""
if hasattr(instance, self.STATE_ATTR):
return False
- elif self.class_ is not instance.__class__ and \
- self.is_mapped:
+ elif self.class_ is not instance.__class__ and self.is_mapped:
# this will create a new ClassManager for the
# subclass, without a mapper. This is likely a
# user error situation but allow the object
# to be constructed, so that it is usable
# in a non-ORM context at least.
- return self._subclass_manager(instance.__class__).\
- _new_state_if_none(instance)
+ return self._subclass_manager(
+ instance.__class__
+ )._new_state_if_none(instance)
else:
state = self._state_constructor(instance, self)
self._state_setter(instance, state)
@@ -371,8 +385,11 @@ class ClassManager(dict):
__nonzero__ = __bool__
def __repr__(self):
- return '<%s of %r at %x>' % (
- self.__class__.__name__, self.class_, id(self))
+ return "<%s of %r at %x>" % (
+ self.__class__.__name__,
+ self.class_,
+ id(self),
+ )
class _SerializeManager(object):
@@ -396,8 +413,8 @@ class _SerializeManager(object):
"Cannot deserialize object of type %r - "
"no mapper() has "
"been configured for this class within the current "
- "Python process!" %
- self.class_)
+ "Python process!" % self.class_,
+ )
elif manager.is_mapped and not manager.mapper.configured:
manager.mapper._configure_all()
@@ -447,6 +464,7 @@ class InstrumentationFactory(object):
if ClassManager.MANAGER_ATTR in class_.__dict__:
delattr(class_, ClassManager.MANAGER_ATTR)
+
# this attribute is replaced by sqlalchemy.ext.instrumentation
# when imported.
_instrumentation_factory = InstrumentationFactory()
@@ -488,8 +506,9 @@ def is_instrumented(instance, key):
applied directly to the class, i.e. no descriptors are required.
"""
- return manager_of_class(instance.__class__).\
- is_instrumented(key, search=True)
+ return manager_of_class(instance.__class__).is_instrumented(
+ key, search=True
+ )
def _generate_init(class_, class_manager):
@@ -518,15 +537,15 @@ def __init__(%(apply_pos)s):
func_text = func_body % func_vars
if util.py2k:
- func = getattr(original__init__, 'im_func', original__init__)
- func_defaults = getattr(func, 'func_defaults', None)
+ func = getattr(original__init__, "im_func", original__init__)
+ func_defaults = getattr(func, "func_defaults", None)
else:
- func_defaults = getattr(original__init__, '__defaults__', None)
- func_kw_defaults = getattr(original__init__, '__kwdefaults__', None)
+ func_defaults = getattr(original__init__, "__defaults__", None)
+ func_kw_defaults = getattr(original__init__, "__kwdefaults__", None)
env = locals().copy()
exec(func_text, env)
- __init__ = env['__init__']
+ __init__ = env["__init__"]
__init__.__doc__ = original__init__.__doc__
__init__._sa_original_init = original__init__
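
Editor's note: for orientation, a sketch poking at the ClassManager machinery above (Account as before):

    from sqlalchemy.orm.instrumentation import is_instrumented, manager_of_class

    mgr = manager_of_class(Account)
    print(mgr.is_mapped)   # True: "mapper" is present in mgr.__dict__
    print(sorted(mgr))     # ClassManager is a dict keyed by attribute name
    print(is_instrumented(Account(id=1), "balance"))  # True, per the helper above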
diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py
index 80d0a6303..d7e70c5d7 100644
--- a/lib/sqlalchemy/orm/interfaces.py
+++ b/lib/sqlalchemy/orm/interfaces.py
@@ -22,8 +22,15 @@ from __future__ import absolute_import
from .. import util
from ..sql import operators
-from .base import (ONETOMANY, MANYTOONE, MANYTOMANY,
- EXT_CONTINUE, EXT_STOP, EXT_SKIP, NOT_EXTENSION)
+from .base import (
+ ONETOMANY,
+ MANYTOONE,
+ MANYTOMANY,
+ EXT_CONTINUE,
+ EXT_STOP,
+ EXT_SKIP,
+ NOT_EXTENSION,
+)
from .base import InspectionAttr, InspectionAttrInfo, _MappedAttribute
import collections
from .. import inspect
@@ -33,21 +40,21 @@ from . import path_registry
MapperExtension = SessionExtension = AttributeExtension = None
__all__ = (
- 'AttributeExtension',
- 'EXT_CONTINUE',
- 'EXT_STOP',
- 'EXT_SKIP',
- 'ONETOMANY',
- 'MANYTOMANY',
- 'MANYTOONE',
- 'NOT_EXTENSION',
- 'LoaderStrategy',
- 'MapperExtension',
- 'MapperOption',
- 'MapperProperty',
- 'PropComparator',
- 'SessionExtension',
- 'StrategizedProperty',
+ "AttributeExtension",
+ "EXT_CONTINUE",
+ "EXT_STOP",
+ "EXT_SKIP",
+ "ONETOMANY",
+ "MANYTOMANY",
+ "MANYTOONE",
+ "NOT_EXTENSION",
+ "LoaderStrategy",
+ "MapperExtension",
+ "MapperOption",
+ "MapperProperty",
+ "PropComparator",
+ "SessionExtension",
+ "StrategizedProperty",
)
@@ -64,8 +71,11 @@ class MapperProperty(_MappedAttribute, InspectionAttr, util.MemoizedSlots):
"""
__slots__ = (
- '_configure_started', '_configure_finished', 'parent', 'key',
- 'info'
+ "_configure_started",
+ "_configure_finished",
+ "parent",
+ "key",
+ "info",
)
cascade = frozenset()
@@ -118,15 +128,17 @@ class MapperProperty(_MappedAttribute, InspectionAttr, util.MemoizedSlots):
"""
- def create_row_processor(self, context, path,
- mapper, result, adapter, populators):
+ def create_row_processor(
+ self, context, path, mapper, result, adapter, populators
+ ):
"""Produce row processing functions and append to the given
set of populators lists.
"""
- def cascade_iterator(self, type_, state, visited_instances=None,
- halt_on=None):
+ def cascade_iterator(
+ self, type_, state, visited_instances=None, halt_on=None
+ ):
"""Iterate through instances related to the given instance for
a particular 'cascade', starting with this MapperProperty.
@@ -234,17 +246,28 @@ class MapperProperty(_MappedAttribute, InspectionAttr, util.MemoizedSlots):
"""
- def merge(self, session, source_state, source_dict, dest_state,
- dest_dict, load, _recursive, _resolve_conflict_map):
+ def merge(
+ self,
+ session,
+ source_state,
+ source_dict,
+ dest_state,
+ dest_dict,
+ load,
+ _recursive,
+ _resolve_conflict_map,
+ ):
"""Merge the attribute represented by this ``MapperProperty``
from source to destination object.
"""
def __repr__(self):
- return '<%s at 0x%x; %s>' % (
+ return "<%s at 0x%x; %s>" % (
self.__class__.__name__,
- id(self), getattr(self, 'key', 'no key'))
+ id(self),
+ getattr(self, "key", "no key"),
+ )
class PropComparator(operators.ColumnOperators):
@@ -335,7 +358,7 @@ class PropComparator(operators.ColumnOperators):
"""
- __slots__ = 'prop', 'property', '_parententity', '_adapt_to_entity'
+ __slots__ = "prop", "property", "_parententity", "_adapt_to_entity"
def __init__(self, prop, parentmapper, adapt_to_entity=None):
self.prop = self.property = prop
@@ -467,21 +490,27 @@ class StrategizedProperty(MapperProperty):
"""
__slots__ = (
- '_strategies', 'strategy',
- '_wildcard_token', '_default_path_loader_key'
+ "_strategies",
+ "strategy",
+ "_wildcard_token",
+ "_default_path_loader_key",
)
strategy_wildcard_key = None
def _memoized_attr__wildcard_token(self):
- return ("%s:%s" % (
- self.strategy_wildcard_key, path_registry._WILDCARD_TOKEN), )
+ return (
+ "%s:%s"
+ % (self.strategy_wildcard_key, path_registry._WILDCARD_TOKEN),
+ )
def _memoized_attr__default_path_loader_key(self):
return (
"loader",
- ("%s:%s" % (
- self.strategy_wildcard_key, path_registry._DEFAULT_TOKEN), )
+ (
+ "%s:%s"
+ % (self.strategy_wildcard_key, path_registry._DEFAULT_TOKEN),
+ ),
)
def _get_context_loader(self, context, path):
@@ -496,7 +525,7 @@ class StrategizedProperty(MapperProperty):
for path_key in (
search_path._loader_key,
search_path._wildcard_path_loader_key,
- search_path._default_path_loader_key
+ search_path._default_path_loader_key,
):
if path_key in context.attributes:
load = context.attributes[path_key]
@@ -509,12 +538,12 @@ class StrategizedProperty(MapperProperty):
return self._strategies[key]
except KeyError:
cls = self._strategy_lookup(*key)
- self._strategies[key] = self._strategies[
- cls] = strategy = cls(self, key)
+ self._strategies[key] = self._strategies[cls] = strategy = cls(
+ self, key
+ )
return strategy
- def setup(
- self, context, entity, path, adapter, **kwargs):
+ def setup(self, context, entity, path, adapter, **kwargs):
loader = self._get_context_loader(context, path)
if loader and loader.strategy:
strat = self._get_strategy(loader.strategy)
@@ -523,24 +552,26 @@ class StrategizedProperty(MapperProperty):
strat.setup_query(context, entity, path, loader, adapter, **kwargs)
def create_row_processor(
- self, context, path, mapper,
- result, adapter, populators):
+ self, context, path, mapper, result, adapter, populators
+ ):
loader = self._get_context_loader(context, path)
if loader and loader.strategy:
strat = self._get_strategy(loader.strategy)
else:
strat = self.strategy
strat.create_row_processor(
- context, path, loader,
- mapper, result, adapter, populators)
+ context, path, loader, mapper, result, adapter, populators
+ )
def do_init(self):
self._strategies = {}
self.strategy = self._get_strategy(self.strategy_key)
def post_instrument_class(self, mapper):
- if not self.parent.non_primary and \
- not mapper.class_manager._attr_has_impl(self.key):
+ if (
+ not self.parent.non_primary
+ and not mapper.class_manager._attr_has_impl(self.key)
+ ):
self.strategy.init_class_attribute(mapper)
_all_strategies = collections.defaultdict(dict)
@@ -550,12 +581,13 @@ class StrategizedProperty(MapperProperty):
def decorate(dec_cls):
# ensure each subclass of the strategy has its
# own _strategy_keys collection
- if '_strategy_keys' not in dec_cls.__dict__:
+ if "_strategy_keys" not in dec_cls.__dict__:
dec_cls._strategy_keys = []
key = tuple(sorted(kw.items()))
cls._all_strategies[cls][key] = dec_cls
dec_cls._strategy_keys.append(key)
return dec_cls
+
return decorate
@classmethod
@@ -671,8 +703,14 @@ class LoaderStrategy(object):
"""
- __slots__ = 'parent_property', 'is_class_level', 'parent', 'key', \
- 'strategy_key', 'strategy_opts'
+ __slots__ = (
+ "parent_property",
+ "is_class_level",
+ "parent",
+ "key",
+ "strategy_key",
+ "strategy_opts",
+ )
def __init__(self, parent, strategy_key):
self.parent_property = parent
@@ -695,8 +733,9 @@ class LoaderStrategy(object):
"""
- def create_row_processor(self, context, path, loadopt, mapper,
- result, adapter, populators):
+ def create_row_processor(
+ self, context, path, loadopt, mapper, result, adapter, populators
+ ):
"""Establish row processing functions for a given QueryContext.
This method fulfills the contract specified by
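
Editor's note: a short sketch of the MapperProperty surface described above (Account as before); each mapped attribute pairs a MapperProperty with a PropComparator:

    from sqlalchemy import inspect

    mapper = inspect(Account)
    for prop in mapper.iterate_properties:
        # __repr__ above renders e.g. "<ColumnProperty at 0x...; balance>"
        print(repr(prop))

    print(Account.balance.property)  # the PropComparator's .prop/.property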
diff --git a/lib/sqlalchemy/orm/loading.py b/lib/sqlalchemy/orm/loading.py
index 0a6f8023a..96eddcb32 100644
--- a/lib/sqlalchemy/orm/loading.py
+++ b/lib/sqlalchemy/orm/loading.py
@@ -37,32 +37,35 @@ def instances(query, cursor, context):
filtered = query._has_mapper_entities
- single_entity = not query._only_return_tuples and \
- len(query._entities) == 1 and \
- query._entities[0].supports_single_entity
+ single_entity = (
+ not query._only_return_tuples
+ and len(query._entities) == 1
+ and query._entities[0].supports_single_entity
+ )
if filtered:
if single_entity:
filter_fn = id
else:
+
def filter_fn(row):
return tuple(
- id(item)
- if ent.use_id_for_hash
- else item
+ id(item) if ent.use_id_for_hash else item
for ent, item in zip(query._entities, row)
)
try:
- (process, labels) = \
- list(zip(*[
- query_entity.row_processor(query,
- context, cursor)
- for query_entity in query._entities
- ]))
+ (process, labels) = list(
+ zip(
+ *[
+ query_entity.row_processor(query, context, cursor)
+ for query_entity in query._entities
+ ]
+ )
+ )
if not single_entity:
- keyed_tuple = util.lightweight_named_tuple('result', labels)
+ keyed_tuple = util.lightweight_named_tuple("result", labels)
while True:
context.partials = {}
@@ -78,11 +81,12 @@ def instances(query, cursor, context):
proc = process[0]
rows = [proc(row) for row in fetch]
else:
- rows = [keyed_tuple([proc(row) for proc in process])
- for row in fetch]
+ rows = [
+ keyed_tuple([proc(row) for proc in process])
+ for row in fetch
+ ]
- for path, post_load in \
- context.post_load_paths.items():
+ for path, post_load in context.post_load_paths.items():
post_load.invoke(context, path)
if filtered:
@@ -113,19 +117,27 @@ def merge_result(querylib, query, iterator, load=True):
single_entity = len(query._entities) == 1
if single_entity:
if isinstance(query._entities[0], querylib._MapperEntity):
- result = [session._merge(
- attributes.instance_state(instance),
- attributes.instance_dict(instance),
- load=load, _recursive={}, _resolve_conflict_map={})
- for instance in iterator]
+ result = [
+ session._merge(
+ attributes.instance_state(instance),
+ attributes.instance_dict(instance),
+ load=load,
+ _recursive={},
+ _resolve_conflict_map={},
+ )
+ for instance in iterator
+ ]
else:
result = list(iterator)
else:
- mapped_entities = [i for i, e in enumerate(query._entities)
- if isinstance(e, querylib._MapperEntity)]
+ mapped_entities = [
+ i
+ for i, e in enumerate(query._entities)
+ if isinstance(e, querylib._MapperEntity)
+ ]
result = []
keys = [ent._label_name for ent in query._entities]
- keyed_tuple = util.lightweight_named_tuple('result', keys)
+ keyed_tuple = util.lightweight_named_tuple("result", keys)
for row in iterator:
newrow = list(row)
for i in mapped_entities:
@@ -133,7 +145,10 @@ def merge_result(querylib, query, iterator, load=True):
newrow[i] = session._merge(
attributes.instance_state(newrow[i]),
attributes.instance_dict(newrow[i]),
- load=load, _recursive={}, _resolve_conflict_map={})
+ load=load,
+ _recursive={},
+ _resolve_conflict_map={},
+ )
result.append(keyed_tuple(newrow))
return iter(result)
@@ -170,9 +185,9 @@ def get_from_identity(session, key, passive):
return None
-def load_on_ident(query, key,
- refresh_state=None, with_for_update=None,
- only_load_props=None):
+def load_on_ident(
+ query, key, refresh_state=None, with_for_update=None, only_load_props=None
+):
"""Load the given identity key from the database."""
if key is not None:
@@ -182,16 +197,23 @@ def load_on_ident(query, key,
ident = identity_token = None
return load_on_pk_identity(
- query, ident, refresh_state=refresh_state,
+ query,
+ ident,
+ refresh_state=refresh_state,
with_for_update=with_for_update,
only_load_props=only_load_props,
- identity_token=identity_token
+ identity_token=identity_token,
)
-def load_on_pk_identity(query, primary_key_identity,
- refresh_state=None, with_for_update=None,
- only_load_props=None, identity_token=None):
+def load_on_pk_identity(
+ query,
+ primary_key_identity,
+ refresh_state=None,
+ with_for_update=None,
+ only_load_props=None,
+ identity_token=None,
+):
"""Load the given primary key identity from the database."""
@@ -209,22 +231,28 @@ def load_on_pk_identity(query, primary_key_identity,
# None present in ident - turn those comparisons
# into "IS NULL"
if None in primary_key_identity:
- nones = set([
- _get_params[col].key for col, value in
- zip(mapper.primary_key, primary_key_identity)
- if value is None
- ])
- _get_clause = sql_util.adapt_criterion_to_null(
- _get_clause, nones)
+ nones = set(
+ [
+ _get_params[col].key
+ for col, value in zip(
+ mapper.primary_key, primary_key_identity
+ )
+ if value is None
+ ]
+ )
+ _get_clause = sql_util.adapt_criterion_to_null(_get_clause, nones)
_get_clause = q._adapt_clause(_get_clause, True, False)
q._criterion = _get_clause
- params = dict([
- (_get_params[primary_key].key, id_val)
- for id_val, primary_key
- in zip(primary_key_identity, mapper.primary_key)
- ])
+ params = dict(
+ [
+ (_get_params[primary_key].key, id_val)
+ for id_val, primary_key in zip(
+ primary_key_identity, mapper.primary_key
+ )
+ ]
+ )
q._params = params
@@ -243,7 +271,8 @@ def load_on_pk_identity(query, primary_key_identity,
version_check=version_check,
only_load_props=only_load_props,
refresh_state=refresh_state,
- identity_token=identity_token)
+ identity_token=identity_token,
+ )
q._order_by = None
try:
@@ -253,27 +282,31 @@ def load_on_pk_identity(query, primary_key_identity,
def _setup_entity_query(
- context, mapper, query_entity,
- path, adapter, column_collection,
- with_polymorphic=None, only_load_props=None,
- polymorphic_discriminator=None, **kw):
+ context,
+ mapper,
+ query_entity,
+ path,
+ adapter,
+ column_collection,
+ with_polymorphic=None,
+ only_load_props=None,
+ polymorphic_discriminator=None,
+ **kw
+):
if with_polymorphic:
poly_properties = mapper._iterate_polymorphic_properties(
- with_polymorphic)
+ with_polymorphic
+ )
else:
poly_properties = mapper._polymorphic_properties
quick_populators = {}
- path.set(
- context.attributes,
- "memoized_setups",
- quick_populators)
+ path.set(context.attributes, "memoized_setups", quick_populators)
for value in poly_properties:
- if only_load_props and \
- value.key not in only_load_props:
+ if only_load_props and value.key not in only_load_props:
continue
value.setup(
context,
@@ -286,9 +319,10 @@ def _setup_entity_query(
**kw
)
- if polymorphic_discriminator is not None and \
- polymorphic_discriminator \
- is not mapper.polymorphic_on:
+ if (
+ polymorphic_discriminator is not None
+ and polymorphic_discriminator is not mapper.polymorphic_on
+ ):
if adapter:
pd = adapter.columns[polymorphic_discriminator]
@@ -298,10 +332,16 @@ def _setup_entity_query(
def _instance_processor(
- mapper, context, result, path, adapter,
- only_load_props=None, refresh_state=None,
- polymorphic_discriminator=None,
- _polymorphic_from=None):
+ mapper,
+ context,
+ result,
+ path,
+ adapter,
+ only_load_props=None,
+ refresh_state=None,
+ polymorphic_discriminator=None,
+ _polymorphic_from=None,
+):
"""Produce a mapper level row processor callable
which processes rows into mapped instances."""
@@ -322,11 +362,11 @@ def _instance_processor(
props = mapper._prop_set
if only_load_props is not None:
- props = props.intersection(
- mapper._props[k] for k in only_load_props)
+ props = props.intersection(mapper._props[k] for k in only_load_props)
quick_populators = path.get(
- context.attributes, "memoized_setups", _none_set)
+ context.attributes, "memoized_setups", _none_set
+ )
for prop in props:
if prop in quick_populators:
@@ -334,7 +374,8 @@ def _instance_processor(
col = quick_populators[prop]
if col is _DEFER_FOR_STATE:
populators["new"].append(
- (prop.key, prop._deferred_column_loader))
+ (prop.key, prop._deferred_column_loader)
+ )
elif col is _SET_DEFERRED_EXPIRED:
# note that in this path, we are no longer
# searching in the result to see if the column might
@@ -366,14 +407,19 @@ def _instance_processor(
# will iterate through all of its columns
# to see if one fits
prop.create_row_processor(
- context, path, mapper, result, adapter, populators)
+ context, path, mapper, result, adapter, populators
+ )
else:
prop.create_row_processor(
- context, path, mapper, result, adapter, populators)
+ context, path, mapper, result, adapter, populators
+ )
propagate_options = context.propagate_options
- load_path = context.query._current_path + path \
- if context.query._current_path.path else path
+ load_path = (
+ context.query._current_path + path
+ if context.query._current_path.path
+ else path
+ )
session_identity_map = context.session.identity_map
@@ -391,18 +437,18 @@ def _instance_processor(
identity_token = context.identity_token
if not refresh_state and _polymorphic_from is not None:
- key = ('loader', path.path)
- if (
- key in context.attributes and
- context.attributes[key].strategy ==
- (('selectinload_polymorphic', True), )
+ key = ("loader", path.path)
+ if key in context.attributes and context.attributes[key].strategy == (
+ ("selectinload_polymorphic", True),
):
selectin_load_via = mapper._should_selectin_load(
- context.attributes[key].local_opts['entities'],
- _polymorphic_from)
+ context.attributes[key].local_opts["entities"],
+ _polymorphic_from,
+ )
else:
selectin_load_via = mapper._should_selectin_load(
- None, _polymorphic_from)
+ None, _polymorphic_from
+ )
if selectin_load_via and selectin_load_via is not _polymorphic_from:
# only_load_props goes w/ refresh_state only, and in a refresh
@@ -413,9 +459,13 @@ def _instance_processor(
callable_ = _load_subclass_via_in(context, path, selectin_load_via)
PostLoad.callable_for_path(
- context, load_path, selectin_load_via.mapper,
+ context,
+ load_path,
+ selectin_load_via.mapper,
+ selectin_load_via,
+ callable_,
selectin_load_via,
- callable_, selectin_load_via)
+ )
post_load = PostLoad.for_context(context, load_path, only_load_props)
@@ -425,8 +475,9 @@ def _instance_processor(
# super-rare condition; a refresh is being called
# on a non-instance-key instance; this is meant to only
# occur within a flush()
- refresh_identity_key = \
- mapper._identity_key_from_state(refresh_state)
+ refresh_identity_key = mapper._identity_key_from_state(
+ refresh_state
+ )
else:
refresh_identity_key = None
@@ -452,7 +503,7 @@ def _instance_processor(
identitykey = (
identity_class,
tuple([row[column] for column in pk_cols]),
- identity_token
+ identity_token,
)
instance = session_identity_map.get(identitykey)
@@ -507,8 +558,16 @@ def _instance_processor(
state.load_path = load_path
_populate_full(
- context, row, state, dict_, isnew, load_path,
- loaded_instance, populate_existing, populators)
+ context,
+ row,
+ state,
+ dict_,
+ isnew,
+ load_path,
+ loaded_instance,
+ populate_existing,
+ populators,
+ )
if isnew:
if loaded_instance:
@@ -518,7 +577,8 @@ def _instance_processor(
loaded_as_persistent(context.session, state.obj())
elif refresh_evt:
state.manager.dispatch.refresh(
- state, context, only_load_props)
+ state, context, only_load_props
+ )
if populate_existing or state.modified:
if refresh_state and only_load_props:
@@ -542,13 +602,19 @@ def _instance_processor(
# and add to the "context.partials" collection.
to_load = _populate_partial(
- context, row, state, dict_, isnew, load_path,
- unloaded, populators)
+ context,
+ row,
+ state,
+ dict_,
+ isnew,
+ load_path,
+ unloaded,
+ populators,
+ )
if isnew:
if refresh_evt:
- state.manager.dispatch.refresh(
- state, context, to_load)
+ state.manager.dispatch.refresh(state, context, to_load)
state._commit(dict_, to_load)
@@ -561,8 +627,14 @@ def _instance_processor(
# if we are doing polymorphic, dispatch to a different _instance()
# method specific to the subclass mapper
_instance = _decorate_polymorphic_switch(
- _instance, context, mapper, result, path,
- polymorphic_discriminator, adapter)
+ _instance,
+ context,
+ mapper,
+ result,
+ path,
+ polymorphic_discriminator,
+ adapter,
+ )
return _instance
@@ -581,14 +653,13 @@ def _load_subclass_via_in(context, path, entity):
orig_query = context.query
q2 = q._with_lazyload_options(
- (enable_opt, ) + orig_query._with_options + (disable_opt, ),
- path.parent, cache_path=path
+ (enable_opt,) + orig_query._with_options + (disable_opt,),
+ path.parent,
+ cache_path=path,
)
if orig_query._populate_existing:
- q2.add_criteria(
- lambda q: q.populate_existing()
- )
+ q2.add_criteria(lambda q: q.populate_existing())
q2(context.session).params(
primary_keys=[
@@ -601,8 +672,16 @@ def _load_subclass_via_in(context, path, entity):
def _populate_full(
- context, row, state, dict_, isnew, load_path,
- loaded_instance, populate_existing, populators):
+ context,
+ row,
+ state,
+ dict_,
+ isnew,
+ load_path,
+ loaded_instance,
+ populate_existing,
+ populators,
+):
if isnew:
# first time we are seeing a row with this identity.
state.runid = context.runid
@@ -650,8 +729,8 @@ def _populate_full(
def _populate_partial(
- context, row, state, dict_, isnew, load_path,
- unloaded, populators):
+ context, row, state, dict_, isnew, load_path, unloaded, populators
+):
if not isnew:
to_load = context.partials[state]
@@ -693,19 +772,32 @@ def _validate_version_id(mapper, state, dict_, row, adapter):
if adapter:
version_id_col = adapter.columns[version_id_col]
- if mapper._get_state_attr_by_column(
- state, dict_, mapper.version_id_col) != row[version_id_col]:
+ if (
+ mapper._get_state_attr_by_column(state, dict_, mapper.version_id_col)
+ != row[version_id_col]
+ ):
raise orm_exc.StaleDataError(
"Instance '%s' has version id '%s' which "
"does not match database-loaded version id '%s'."
- % (state_str(state), mapper._get_state_attr_by_column(
- state, dict_, mapper.version_id_col),
- row[version_id_col]))
+ % (
+ state_str(state),
+ mapper._get_state_attr_by_column(
+ state, dict_, mapper.version_id_col
+ ),
+ row[version_id_col],
+ )
+ )
def _decorate_polymorphic_switch(
- instance_fn, context, mapper, result, path,
- polymorphic_discriminator, adapter):
+ instance_fn,
+ context,
+ mapper,
+ result,
+ path,
+ polymorphic_discriminator,
+ adapter,
+):
if polymorphic_discriminator is not None:
polymorphic_on = polymorphic_discriminator
else:
@@ -721,19 +813,22 @@ def _decorate_polymorphic_switch(
sub_mapper = mapper.polymorphic_map[discriminator]
except KeyError:
raise AssertionError(
- "No such polymorphic_identity %r is defined" %
- discriminator)
+ "No such polymorphic_identity %r is defined" % discriminator
+ )
else:
if sub_mapper is mapper:
return None
return _instance_processor(
- sub_mapper, context, result,
- path, adapter, _polymorphic_from=mapper)
+ sub_mapper,
+ context,
+ result,
+ path,
+ adapter,
+ _polymorphic_from=mapper,
+ )
- polymorphic_instances = util.PopulateDict(
- configure_subclass_mapper
- )
+ polymorphic_instances = util.PopulateDict(configure_subclass_mapper)
def polymorphic_instance(row):
discriminator = row[polymorphic_on]
@@ -742,6 +837,7 @@ def _decorate_polymorphic_switch(
if _instance:
return _instance(row)
return instance_fn(row)
+
return polymorphic_instance
@@ -749,7 +845,8 @@ class PostLoad(object):
"""Track loaders and states for "post load" operations.
"""
- __slots__ = 'loaders', 'states', 'load_keys'
+
+ __slots__ = "loaders", "states", "load_keys"
def __init__(self):
self.loaders = {}
@@ -770,8 +867,7 @@ class PostLoad(object):
for token, limit_to_mapper, loader, arg, kw in self.loaders.values():
states = [
(state, overwrite)
- for state, overwrite
- in self.states.items()
+ for state, overwrite in self.states.items()
if state.manager.mapper.isa(limit_to_mapper)
]
if states:
@@ -787,13 +883,15 @@ class PostLoad(object):
@classmethod
def path_exists(self, context, path, key):
- return path.path in context.post_load_paths and \
- key in context.post_load_paths[path.path].loaders
+ return (
+ path.path in context.post_load_paths
+ and key in context.post_load_paths[path.path].loaders
+ )
@classmethod
def callable_for_path(
- cls, context, path, limit_to_mapper, token,
- loader_callable, *arg, **kw):
+ cls, context, path, limit_to_mapper, token, loader_callable, *arg, **kw
+ ):
if path.path in context.post_load_paths:
pl = context.post_load_paths[path.path]
else:
@@ -809,8 +907,8 @@ def load_scalar_attributes(mapper, state, attribute_names):
if not session:
raise orm_exc.DetachedInstanceError(
"Instance %s is not bound to a Session; "
- "attribute refresh operation cannot proceed" %
- (state_str(state)))
+ "attribute refresh operation cannot proceed" % (state_str(state))
+ )
has_key = bool(state.key)
@@ -833,13 +931,12 @@ def load_scalar_attributes(mapper, state, attribute_names):
statement = mapper._optimized_get_statement(state, attribute_names)
if statement is not None:
result = load_on_ident(
- session.query(mapper).
- options(
- strategy_options.Load(mapper).undefer("*")
- ).from_statement(statement),
+ session.query(mapper)
+ .options(strategy_options.Load(mapper).undefer("*"))
+ .from_statement(statement),
None,
only_load_props=attribute_names,
- refresh_state=state
+ refresh_state=state,
)
if result is False:
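
The refresh path above wraps the optimized statement with Load(mapper).undefer("*") so deferred columns come back in the same SELECT. The public spelling of that wildcard is the undefer() query option; a small self-contained sketch (Document is an illustrative name, not from this diff):

    from sqlalchemy import Column, Integer, Text, create_engine
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import Session, deferred, undefer

    Base = declarative_base()

    class Document(Base):
        __tablename__ = "document"
        id = Column(Integer, primary_key=True)
        body = deferred(Column(Text))  # normally loaded on first access

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)
    session = Session(engine)
    session.add(Document(body="hello"))
    session.commit()

    # undefer("*") pulls every deferred column up front, the same
    # wildcard the internal refresh query applies above
    doc = session.query(Document).options(undefer("*")).first()
    print(doc.body)
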
@@ -850,30 +947,34 @@ def load_scalar_attributes(mapper, state, attribute_names):
# object is becoming persistent but hasn't yet been assigned
# an identity_key.
# check here to ensure we have the attrs we need.
- pk_attrs = [mapper._columntoproperty[col].key
- for col in mapper.primary_key]
+ pk_attrs = [
+ mapper._columntoproperty[col].key for col in mapper.primary_key
+ ]
if state.expired_attributes.intersection(pk_attrs):
raise sa_exc.InvalidRequestError(
"Instance %s cannot be refreshed - it's not "
" persistent and does not "
- "contain a full primary key." % state_str(state))
+ "contain a full primary key." % state_str(state)
+ )
identity_key = mapper._identity_key_from_state(state)
- if (_none_set.issubset(identity_key) and
- not mapper.allow_partial_pks) or \
- _none_set.issuperset(identity_key):
+ if (
+ _none_set.issubset(identity_key) and not mapper.allow_partial_pks
+ ) or _none_set.issuperset(identity_key):
util.warn_limited(
"Instance %s to be refreshed doesn't "
"contain a full primary key - can't be refreshed "
"(and shouldn't be expired, either).",
- state_str(state))
+ state_str(state),
+ )
return
result = load_on_ident(
session.query(mapper),
identity_key,
refresh_state=state,
- only_load_props=attribute_names)
+ only_load_props=attribute_names,
+ )
# if instance is pending, a refresh operation
# may not complete (even if PK attributes are assigned)
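
load_scalar_attributes() is what runs when an expired attribute is touched: it refuses to proceed without a Session and warns when the primary key is incomplete. A hedged demonstration of both the implicit refresh and the DetachedInstanceError path, with User as an assumed model:

    from sqlalchemy import Column, Integer, String, create_engine
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import Session
    from sqlalchemy.orm.exc import DetachedInstanceError

    Base = declarative_base()

    class User(Base):
        __tablename__ = "user"
        id = Column(Integer, primary_key=True)
        name = Column(String(50))

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)
    session = Session(engine)
    user = User(name="ed")
    session.add(user)
    session.commit()      # expire_on_commit=True leaves attributes expired

    print(user.name)      # touch -> load_scalar_attributes() emits a SELECT

    session.expire(user)
    session.expunge(user)  # expired and detached
    try:
        user.name
    except DetachedInstanceError as err:
        print(err)        # "...attribute refresh operation cannot proceed"
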
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index fa731f729..ea8890788 100644
--- a/lib/sqlalchemy/orm/mapper.py
+++ b/lib/sqlalchemy/orm/mapper.py
@@ -26,12 +26,21 @@ from ..sql import expression, visitors, operators, util as sql_util
from . import instrumentation, attributes, exc as orm_exc, loading
from . import properties
from . import util as orm_util
-from .interfaces import MapperProperty, InspectionAttr, _MappedAttribute, \
- EXT_SKIP
-
-
-from .base import _class_to_mapper, _state_mapper, class_mapper, \
- state_str, _INSTRUMENTOR
+from .interfaces import (
+ MapperProperty,
+ InspectionAttr,
+ _MappedAttribute,
+ EXT_SKIP,
+)
+
+
+from .base import (
+ _class_to_mapper,
+ _state_mapper,
+ class_mapper,
+ state_str,
+ _INSTRUMENTOR,
+)
from .path_registry import PathRegistry
import sys
@@ -46,7 +55,7 @@ _memoized_configured_property = util.group_expirable_memoized_property()
# a constant returned by _get_attr_by_column to indicate
# this mapper is not handling an attribute for a particular
# column
-NO_ATTRIBUTE = util.symbol('NO_ATTRIBUTE')
+NO_ATTRIBUTE = util.symbol("NO_ATTRIBUTE")
# lock used to synchronize the "mapper configure" step
_CONFIGURE_MUTEX = util.threading.RLock()
@@ -90,38 +99,39 @@ class Mapper(InspectionAttr):
_new_mappers = False
_dispose_called = False
- def __init__(self,
- class_,
- local_table=None,
- properties=None,
- primary_key=None,
- non_primary=False,
- inherits=None,
- inherit_condition=None,
- inherit_foreign_keys=None,
- extension=None,
- order_by=False,
- always_refresh=False,
- version_id_col=None,
- version_id_generator=None,
- polymorphic_on=None,
- _polymorphic_map=None,
- polymorphic_identity=None,
- concrete=False,
- with_polymorphic=None,
- polymorphic_load=None,
- allow_partial_pks=True,
- batch=True,
- column_prefix=None,
- include_properties=None,
- exclude_properties=None,
- passive_updates=True,
- passive_deletes=False,
- confirm_deleted_rows=True,
- eager_defaults=False,
- legacy_is_orphan=False,
- _compiled_cache_size=100,
- ):
+ def __init__(
+ self,
+ class_,
+ local_table=None,
+ properties=None,
+ primary_key=None,
+ non_primary=False,
+ inherits=None,
+ inherit_condition=None,
+ inherit_foreign_keys=None,
+ extension=None,
+ order_by=False,
+ always_refresh=False,
+ version_id_col=None,
+ version_id_generator=None,
+ polymorphic_on=None,
+ _polymorphic_map=None,
+ polymorphic_identity=None,
+ concrete=False,
+ with_polymorphic=None,
+ polymorphic_load=None,
+ allow_partial_pks=True,
+ batch=True,
+ column_prefix=None,
+ include_properties=None,
+ exclude_properties=None,
+ passive_updates=True,
+ passive_deletes=False,
+ confirm_deleted_rows=True,
+ eager_defaults=False,
+ legacy_is_orphan=False,
+ _compiled_cache_size=100,
+ ):
r"""Return a new :class:`~.Mapper` object.
This function is typically used behind the scenes
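
As the docstring notes, Mapper is normally constructed indirectly. The classical form reaches it through the public mapper() function with a class and a Table; a minimal sketch under those assumptions (table and class names are illustrative):

    from sqlalchemy import (
        Column,
        Integer,
        MetaData,
        String,
        Table,
        create_engine,
    )
    from sqlalchemy.orm import Session, mapper

    metadata = MetaData()
    user_table = Table(
        "user",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("name", String(50)),
    )

    class User(object):
        pass

    # classical mapping: Mapper() invoked via the public mapper() function
    mapper(User, user_table)

    engine = create_engine("sqlite://")
    metadata.create_all(engine)
    session = Session(engine)
    session.add(User())
    session.commit()
    print(session.query(User).count())  # 1
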
@@ -588,7 +598,7 @@ class Mapper(InspectionAttr):
"""
- self.class_ = util.assert_arg_type(class_, type, 'class_')
+ self.class_ = util.assert_arg_type(class_, type, "class_")
self.class_manager = None
@@ -600,7 +610,8 @@ class Mapper(InspectionAttr):
util.warn_deprecated(
"Mapper.order_by is deprecated."
"Use Query.order_by() in order to affect the ordering of ORM "
- "result sets.")
+ "result sets."
+ )
else:
self.order_by = order_by
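
The deprecation above points at the query-level spelling: ordering belongs on Query.order_by() rather than on the mapper. The replacement pattern, sketched with an assumed model:

    from sqlalchemy import Column, Integer, String, create_engine
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import Session

    Base = declarative_base()

    class User(Base):
        __tablename__ = "user"
        id = Column(Integer, primary_key=True)
        name = Column(String(50))

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)
    session = Session(engine)
    session.add_all([User(name="wendy"), User(name="ed")])
    session.commit()

    # instead of Mapper(..., order_by=...): order at query time
    for user in session.query(User).order_by(User.name):
        print(user.name)  # ed, wendy
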
@@ -631,7 +642,8 @@ class Mapper(InspectionAttr):
self.eager_defaults = eager_defaults
self.column_prefix = column_prefix
self.polymorphic_on = expression._clause_element_as_expr(
- polymorphic_on)
+ polymorphic_on
+ )
self._dependency_processors = []
self.validators = util.immutabledict()
self.passive_updates = passive_updates
@@ -974,14 +986,16 @@ class Mapper(InspectionAttr):
self.inherits = class_mapper(self.inherits, configure=False)
if not issubclass(self.class_, self.inherits.class_):
raise sa_exc.ArgumentError(
- "Class '%s' does not inherit from '%s'" %
- (self.class_.__name__, self.inherits.class_.__name__))
+ "Class '%s' does not inherit from '%s'"
+ % (self.class_.__name__, self.inherits.class_.__name__)
+ )
if self.non_primary != self.inherits.non_primary:
np = not self.non_primary and "primary" or "non-primary"
raise sa_exc.ArgumentError(
"Inheritance of %s mapper for class '%s' is "
- "only allowed from a %s mapper" %
- (np, self.class_.__name__, np))
+ "only allowed from a %s mapper"
+ % (np, self.class_.__name__, np)
+ )
# inherit_condition is optional.
if self.local_table is None:
self.local_table = self.inherits.local_table
@@ -1000,18 +1014,19 @@ class Mapper(InspectionAttr):
# full table which could pull in other stuff we don't
# want (allows test/inheritance.InheritTest4 to pass)
self.inherit_condition = sql_util.join_condition(
- self.inherits.local_table,
- self.local_table)
+ self.inherits.local_table, self.local_table
+ )
self.mapped_table = sql.join(
self.inherits.mapped_table,
self.local_table,
- self.inherit_condition)
+ self.inherit_condition,
+ )
fks = util.to_set(self.inherit_foreign_keys)
- self._inherits_equated_pairs = \
- sql_util.criterion_as_pairs(
- self.mapped_table.onclause,
- consider_as_foreign_keys=fks)
+ self._inherits_equated_pairs = sql_util.criterion_as_pairs(
+ self.mapped_table.onclause,
+ consider_as_foreign_keys=fks,
+ )
else:
self.mapped_table = self.local_table
@@ -1023,21 +1038,27 @@ class Mapper(InspectionAttr):
if self.version_id_col is None:
self.version_id_col = self.inherits.version_id_col
self.version_id_generator = self.inherits.version_id_generator
- elif self.inherits.version_id_col is not None and \
- self.version_id_col is not self.inherits.version_id_col:
+ elif (
+ self.inherits.version_id_col is not None
+ and self.version_id_col is not self.inherits.version_id_col
+ ):
util.warn(
"Inheriting version_id_col '%s' does not match inherited "
"version_id_col '%s' and will not automatically populate "
"the inherited versioning column. "
"version_id_col should only be specified on "
- "the base-most mapper that includes versioning." %
- (self.version_id_col.description,
- self.inherits.version_id_col.description)
+ "the base-most mapper that includes versioning."
+ % (
+ self.version_id_col.description,
+ self.inherits.version_id_col.description,
+ )
)
- if self.order_by is False and \
- not self.concrete and \
- self.inherits.order_by is not False:
+ if (
+ self.order_by is False
+ and not self.concrete
+ and self.inherits.order_by is not False
+ ):
self.order_by = self.inherits.order_by
self.polymorphic_map = self.inherits.polymorphic_map
@@ -1045,8 +1066,9 @@ class Mapper(InspectionAttr):
self.inherits._inheriting_mappers.append(self)
self.base_mapper = self.inherits.base_mapper
self.passive_updates = self.inherits.passive_updates
- self.passive_deletes = self.inherits.passive_deletes or \
- self.passive_deletes
+ self.passive_deletes = (
+ self.inherits.passive_deletes or self.passive_deletes
+ )
self._all_tables = self.inherits._all_tables
if self.polymorphic_identity is not None:
@@ -1054,25 +1076,30 @@ class Mapper(InspectionAttr):
util.warn(
"Reassigning polymorphic association for identity %r "
"from %r to %r: Check for duplicate use of %r as "
- "value for polymorphic_identity." %
- (self.polymorphic_identity,
- self.polymorphic_map[self.polymorphic_identity],
- self, self.polymorphic_identity)
+ "value for polymorphic_identity."
+ % (
+ self.polymorphic_identity,
+ self.polymorphic_map[self.polymorphic_identity],
+ self,
+ self.polymorphic_identity,
+ )
)
self.polymorphic_map[self.polymorphic_identity] = self
if self.polymorphic_load and self.concrete:
raise exc.ArgumentError(
"polymorphic_load is not currently supported "
- "with concrete table inheritance")
- if self.polymorphic_load == 'inline':
+ "with concrete table inheritance"
+ )
+ if self.polymorphic_load == "inline":
self.inherits._add_with_polymorphic_subclass(self)
- elif self.polymorphic_load == 'selectin':
+ elif self.polymorphic_load == "selectin":
pass
elif self.polymorphic_load is not None:
raise sa_exc.ArgumentError(
- "unknown argument for polymorphic_load: %r" %
- self.polymorphic_load)
+ "unknown argument for polymorphic_load: %r"
+ % self.polymorphic_load
+ )
else:
self._all_tables = set()
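
The branch above wires up the two accepted polymorphic_load values: "inline" folds the subclass into the base mapper's with_polymorphic selectable, "selectin" defers the subclass columns to a follow-up SELECT ... WHERE ... IN load, and anything else (or combining the option with concrete inheritance) raises ArgumentError. A joined-inheritance sketch with assumed names:

    from sqlalchemy import Column, ForeignKey, Integer, String
    from sqlalchemy.ext.declarative import declarative_base

    Base = declarative_base()

    class Employee(Base):
        __tablename__ = "employee"
        id = Column(Integer, primary_key=True)
        type = Column(String(20))
        __mapper_args__ = {
            "polymorphic_on": type,
            "polymorphic_identity": "employee",
        }

    class Engineer(Employee):
        __tablename__ = "engineer"
        id = Column(Integer, ForeignKey("employee.id"), primary_key=True)
        # load Engineer's extra columns via a second SELECT ... IN query
        # instead of joining the engineer table into every base query
        __mapper_args__ = {
            "polymorphic_identity": "engineer",
            "polymorphic_load": "selectin",
        }
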
@@ -1084,15 +1111,16 @@ class Mapper(InspectionAttr):
if self.mapped_table is None:
raise sa_exc.ArgumentError(
- "Mapper '%s' does not have a mapped_table specified."
- % self)
+ "Mapper '%s' does not have a mapped_table specified." % self
+ )
def _set_with_polymorphic(self, with_polymorphic):
- if with_polymorphic == '*':
- self.with_polymorphic = ('*', None)
+ if with_polymorphic == "*":
+ self.with_polymorphic = ("*", None)
elif isinstance(with_polymorphic, (tuple, list)):
if isinstance(
- with_polymorphic[0], util.string_types + (tuple, list)):
+ with_polymorphic[0], util.string_types + (tuple, list)
+ ):
self.with_polymorphic = with_polymorphic
else:
self.with_polymorphic = (with_polymorphic, None)
@@ -1109,11 +1137,13 @@ class Mapper(InspectionAttr):
"SELECT from a subquery that does not have an alias."
)
- if self.with_polymorphic and \
- isinstance(self.with_polymorphic[1],
- expression.SelectBase):
- self.with_polymorphic = (self.with_polymorphic[0],
- self.with_polymorphic[1].alias())
+ if self.with_polymorphic and isinstance(
+ self.with_polymorphic[1], expression.SelectBase
+ ):
+ self.with_polymorphic = (
+ self.with_polymorphic[0],
+ self.with_polymorphic[1].alias(),
+ )
if self.configured:
self._expire_memoizations()
@@ -1122,12 +1152,9 @@ class Mapper(InspectionAttr):
subcl = mapper.class_
if self.with_polymorphic is None:
self._set_with_polymorphic((subcl,))
- elif self.with_polymorphic[0] != '*':
+ elif self.with_polymorphic[0] != "*":
self._set_with_polymorphic(
- (
- self.with_polymorphic[0] + (subcl, ),
- self.with_polymorphic[1]
- )
+ (self.with_polymorphic[0] + (subcl,), self.with_polymorphic[1])
)
def _set_concrete_base(self, mapper):
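
_set_with_polymorphic above normalizes the shorthand forms: "*" becomes ("*", None), a plain sequence of classes becomes (classes, None), and a SELECT passed as the second element gets aliased. The public entry point is orm.with_polymorphic(); a sketch with assumed names:

    from sqlalchemy import Column, ForeignKey, Integer, String
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import Session, with_polymorphic

    Base = declarative_base()

    class Employee(Base):
        __tablename__ = "employee"
        id = Column(Integer, primary_key=True)
        type = Column(String(20))
        __mapper_args__ = {
            "polymorphic_on": type,
            "polymorphic_identity": "employee",
        }

    class Engineer(Employee):
        __tablename__ = "engineer"
        id = Column(Integer, ForeignKey("employee.id"), primary_key=True)
        __mapper_args__ = {"polymorphic_identity": "engineer"}

    # "*" is the shorthand normalized to ("*", None) by the code above;
    # emits: SELECT ... FROM employee LEFT OUTER JOIN engineer ON ...
    poly = with_polymorphic(Employee, "*")
    print(Session().query(poly))
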
@@ -1152,9 +1179,9 @@ class Mapper(InspectionAttr):
self._all_tables = self.inherits._all_tables
for key, prop in mapper._props.items():
- if key not in self._props and \
- not self._should_exclude(key, key, local=False,
- column=None):
+ if key not in self._props and not self._should_exclude(
+ key, key, local=False, column=None
+ ):
self._adapt_inherited_property(key, prop, False)
def _set_polymorphic_on(self, polymorphic_on):
@@ -1166,8 +1193,13 @@ class Mapper(InspectionAttr):
if self.inherits:
self.dispatch._update(self.inherits.dispatch)
super_extensions = set(
- chain(*[m._deprecated_extensions
- for m in self.inherits.iterate_to_root()]))
+ chain(
+ *[
+ m._deprecated_extensions
+ for m in self.inherits.iterate_to_root()
+ ]
+ )
+ )
else:
super_extensions = set()
@@ -1178,8 +1210,13 @@ class Mapper(InspectionAttr):
def _configure_listeners(self):
if self.inherits:
super_extensions = set(
- chain(*[m._deprecated_extensions
- for m in self.inherits.iterate_to_root()]))
+ chain(
+ *[
+ m._deprecated_extensions
+ for m in self.inherits.iterate_to_root()
+ ]
+ )
+ )
else:
super_extensions = set()
@@ -1206,7 +1243,8 @@ class Mapper(InspectionAttr):
raise sa_exc.InvalidRequestError(
"Class %s has no primary mapper configured. Configure "
"a primary mapper first before setting up a non primary "
- "Mapper." % self.class_)
+ "Mapper." % self.class_
+ )
self.class_manager = manager
self._identity_class = manager.mapper._identity_class
_mapper_registry[self] = True
@@ -1219,12 +1257,13 @@ class Mapper(InspectionAttr):
"Class '%s' already has a primary mapper defined. "
"Use non_primary=True to "
"create a non primary Mapper. clear_mappers() will "
- "remove *all* current mappers from all classes." %
- self.class_)
+ "remove *all* current mappers from all classes."
+ % self.class_
+ )
# else:
- # a ClassManager may already exist as
- # ClassManager.instrument_attribute() creates
- # new managers for each subclass if they don't yet exist.
+ # a ClassManager may already exist as
+ # ClassManager.instrument_attribute() creates
+ # new managers for each subclass if they don't yet exist.
_mapper_registry[self] = True
@@ -1239,33 +1278,35 @@ class Mapper(InspectionAttr):
manager.mapper = self
manager.deferred_scalar_loader = util.partial(
- loading.load_scalar_attributes, self)
+ loading.load_scalar_attributes, self
+ )
# The remaining members can be added by any mapper,
# e_name None or not.
if manager.info.get(_INSTRUMENTOR, False):
return
- event.listen(manager, 'first_init', _event_on_first_init, raw=True)
- event.listen(manager, 'init', _event_on_init, raw=True)
+ event.listen(manager, "first_init", _event_on_first_init, raw=True)
+ event.listen(manager, "init", _event_on_init, raw=True)
for key, method in util.iterate_attributes(self.class_):
- if key == '__init__' and hasattr(method, '_sa_original_init'):
+ if key == "__init__" and hasattr(method, "_sa_original_init"):
method = method._sa_original_init
if isinstance(method, types.MethodType):
method = method.im_func
if isinstance(method, types.FunctionType):
- if hasattr(method, '__sa_reconstructor__'):
+ if hasattr(method, "__sa_reconstructor__"):
self._reconstructor = method
- event.listen(manager, 'load', _event_on_load, raw=True)
- elif hasattr(method, '__sa_validators__'):
+ event.listen(manager, "load", _event_on_load, raw=True)
+ elif hasattr(method, "__sa_validators__"):
validation_opts = method.__sa_validation_opts__
for name in method.__sa_validators__:
if name in self.validators:
raise sa_exc.InvalidRequestError(
"A validation function for mapped "
- "attribute %r on mapper %s already exists." %
- (name, self))
+ "attribute %r on mapper %s already exists."
+ % (name, self)
+ )
self.validators = self.validators.union(
{name: (method, validation_opts)}
)
@@ -1283,13 +1324,15 @@ class Mapper(InspectionAttr):
self.configured = True
self._dispose_called = True
- if hasattr(self, '_configure_failed'):
+ if hasattr(self, "_configure_failed"):
del self._configure_failed
- if not self.non_primary and \
- self.class_manager is not None and \
- self.class_manager.is_mapped and \
- self.class_manager.mapper is self:
+ if (
+ not self.non_primary
+ and self.class_manager is not None
+ and self.class_manager.is_mapped
+ and self.class_manager.mapper is self
+ ):
instrumentation.unregister_class(self.class_)
def _configure_pks(self):
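
Among the instrumentation wiring above, a method carrying __sa_reconstructor__ is hooked to the 'load' event, so it runs when instances are rebuilt from rows, where __init__ is bypassed. The public decorator is orm.reconstructor(); a sketch with an assumed User class:

    from sqlalchemy import Column, Integer, String, create_engine
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import Session, reconstructor

    Base = declarative_base()

    class User(Base):
        __tablename__ = "user"
        id = Column(Integer, primary_key=True)
        name = Column(String(50))

        def __init__(self, name):
            self.name = name
            self.caps = name.upper()

        @reconstructor
        def init_on_load(self):
            # runs via the 'load' event registered above; __init__ is
            # skipped when instances are rebuilt from rows
            self.caps = (self.name or "").upper()

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)
    session = Session(engine)
    session.add(User("ed"))
    session.commit()
    session.close()

    print(Session(engine).query(User).first().caps)  # ED
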
@@ -1298,9 +1341,9 @@ class Mapper(InspectionAttr):
self._pks_by_table = {}
self._cols_by_table = {}
- all_cols = util.column_set(chain(*[
- col.proxy_set for col in
- self._columntoproperty]))
+ all_cols = util.column_set(
+ chain(*[col.proxy_set for col in self._columntoproperty])
+ )
pk_cols = util.column_set(c for c in all_cols if c.primary_key)
@@ -1311,12 +1354,12 @@ class Mapper(InspectionAttr):
if t.primary_key and pk_cols.issuperset(t.primary_key):
# ordering is important since it determines the ordering of
# mapper.primary_key (and therefore query.get())
- self._pks_by_table[t] = \
- util.ordered_column_set(t.primary_key).\
- intersection(pk_cols)
- self._cols_by_table[t] = \
- util.ordered_column_set(t.c).\
- intersection(all_cols)
+ self._pks_by_table[t] = util.ordered_column_set(
+ t.primary_key
+ ).intersection(pk_cols)
+ self._cols_by_table[t] = util.ordered_column_set(t.c).intersection(
+ all_cols
+ )
# if explicit PK argument sent, add those columns to the
# primary key mappings
@@ -1327,22 +1370,30 @@ class Mapper(InspectionAttr):
self._pks_by_table[k.table].add(k)
# otherwise, see that we got a full PK for the mapped table
- elif self.mapped_table not in self._pks_by_table or \
- len(self._pks_by_table[self.mapped_table]) == 0:
+ elif (
+ self.mapped_table not in self._pks_by_table
+ or len(self._pks_by_table[self.mapped_table]) == 0
+ ):
raise sa_exc.ArgumentError(
"Mapper %s could not assemble any primary "
- "key columns for mapped table '%s'" %
- (self, self.mapped_table.description))
- elif self.local_table not in self._pks_by_table and \
- isinstance(self.local_table, schema.Table):
- util.warn("Could not assemble any primary "
- "keys for locally mapped table '%s' - "
- "no rows will be persisted in this Table."
- % self.local_table.description)
-
- if self.inherits and \
- not self.concrete and \
- not self._primary_key_argument:
+ "key columns for mapped table '%s'"
+ % (self, self.mapped_table.description)
+ )
+ elif self.local_table not in self._pks_by_table and isinstance(
+ self.local_table, schema.Table
+ ):
+ util.warn(
+ "Could not assemble any primary "
+ "keys for locally mapped table '%s' - "
+ "no rows will be persisted in this Table."
+ % self.local_table.description
+ )
+
+ if (
+ self.inherits
+ and not self.concrete
+ and not self._primary_key_argument
+ ):
# if inheriting, the "primary key" for this mapper is
# that of the inheriting (unless concrete or explicit)
self.primary_key = self.inherits.primary_key
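
_configure_pks assembles candidate key columns per table, raising when the mapped table yields none and only warning for a local table. When the underlying table genuinely lacks a PRIMARY KEY constraint (a reflected view, say), the mapper-level primary_key argument from __init__ supplies one; a hedged sketch with illustrative names:

    from sqlalchemy import Column, Integer, String, Table
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import configure_mappers

    Base = declarative_base()

    # e.g. a database view reflected without any PRIMARY KEY constraint
    stuff_view = Table(
        "stuff_view",
        Base.metadata,
        Column("id", Integer),
        Column("name", String(50)),
    )

    class Stuff(Base):
        __table__ = stuff_view
        # tell the mapper which columns act as the candidate key, avoiding
        # the "could not assemble any primary key columns" error above
        __mapper_args__ = {"primary_key": [stuff_view.c.id]}

    configure_mappers()  # configures cleanly thanks to the explicit key
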
@@ -1351,19 +1402,24 @@ class Mapper(InspectionAttr):
# reduce to the minimal set of columns
if self._primary_key_argument:
primary_key = sql_util.reduce_columns(
- [self.mapped_table.corresponding_column(c) for c in
- self._primary_key_argument],
- ignore_nonexistent_tables=True)
+ [
+ self.mapped_table.corresponding_column(c)
+ for c in self._primary_key_argument
+ ],
+ ignore_nonexistent_tables=True,
+ )
else:
primary_key = sql_util.reduce_columns(
self._pks_by_table[self.mapped_table],
- ignore_nonexistent_tables=True)
+ ignore_nonexistent_tables=True,
+ )
if len(primary_key) == 0:
raise sa_exc.ArgumentError(
"Mapper %s could not assemble any primary "
- "key columns for mapped table '%s'" %
- (self, self.mapped_table.description))
+ "key columns for mapped table '%s'"
+ % (self, self.mapped_table.description)
+ )
self.primary_key = tuple(primary_key)
self._log("Identified primary key columns: %s", primary_key)
@@ -1373,9 +1429,12 @@ class Mapper(InspectionAttr):
self._readonly_props = set(
self._columntoproperty[col]
for col in self._columntoproperty
- if self._columntoproperty[col] not in self._identity_key_props and
- (not hasattr(col, 'table') or
- col.table not in self._cols_by_table))
+ if self._columntoproperty[col] not in self._identity_key_props
+ and (
+ not hasattr(col, "table")
+ or col.table not in self._cols_by_table
+ )
+ )
def _configure_properties(self):
# Column and other ClauseElement objects which are mapped
@@ -1397,9 +1456,9 @@ class Mapper(InspectionAttr):
# pull properties from the inherited mapper if any.
if self.inherits:
for key, prop in self.inherits._props.items():
- if key not in self._props and \
- not self._should_exclude(key, key, local=False,
- column=None):
+ if key not in self._props and not self._should_exclude(
+ key, key, local=False, column=None
+ ):
self._adapt_inherited_property(key, prop, False)
# create properties for each column in the mapped table,
@@ -1408,12 +1467,13 @@ class Mapper(InspectionAttr):
if column in self._columntoproperty:
continue
- column_key = (self.column_prefix or '') + column.key
+ column_key = (self.column_prefix or "") + column.key
if self._should_exclude(
- column.key, column_key,
+ column.key,
+ column_key,
local=self.local_table.c.contains_column(column),
- column=column
+ column=column,
):
continue
@@ -1423,10 +1483,9 @@ class Mapper(InspectionAttr):
if column in mapper._columntoproperty:
column_key = mapper._columntoproperty[column].key
- self._configure_property(column_key,
- column,
- init=False,
- setparent=True)
+ self._configure_property(
+ column_key, column, init=False, setparent=True
+ )
def _configure_polymorphic_setter(self, init=False):
"""Configure an attribute on the mapper representing the
@@ -1453,7 +1512,8 @@ class Mapper(InspectionAttr):
raise sa_exc.ArgumentError(
"Can't determine polymorphic_on "
"value '%s' - no attribute is "
- "mapped to this name." % self.polymorphic_on)
+ "mapped to this name." % self.polymorphic_on
+ )
if self.polymorphic_on in self._columntoproperty:
# polymorphic_on is a column that is already mapped
@@ -1462,12 +1522,14 @@ class Mapper(InspectionAttr):
elif isinstance(self.polymorphic_on, MapperProperty):
# polymorphic_on is directly a MapperProperty,
# ensure it's a ColumnProperty
- if not isinstance(self.polymorphic_on,
- properties.ColumnProperty):
+ if not isinstance(
+ self.polymorphic_on, properties.ColumnProperty
+ ):
raise sa_exc.ArgumentError(
"Only direct column-mapped "
"property or SQL expression "
- "can be passed for polymorphic_on")
+ "can be passed for polymorphic_on"
+ )
prop = self.polymorphic_on
elif not expression._is_column(self.polymorphic_on):
# polymorphic_on is not a Column and not a ColumnProperty;
@@ -1484,7 +1546,8 @@ class Mapper(InspectionAttr):
# 2. a totally standalone SQL expression which we'd
# hope is compatible with this mapper's mapped_table
col = self.mapped_table.corresponding_column(
- self.polymorphic_on)
+ self.polymorphic_on
+ )
if col is None:
# polymorphic_on doesn't derive from any
# column/expression isn't present in the mapped
@@ -1500,14 +1563,16 @@ class Mapper(InspectionAttr):
instrument = False
col = self.polymorphic_on
if isinstance(col, schema.Column) and (
- self.with_polymorphic is None or
- self.with_polymorphic[1].
- corresponding_column(col) is None):
+ self.with_polymorphic is None
+ or self.with_polymorphic[1].corresponding_column(col)
+ is None
+ ):
raise sa_exc.InvalidRequestError(
"Could not map polymorphic_on column "
"'%s' to the mapped table - polymorphic "
"loads will not function properly"
- % col.description)
+ % col.description
+ )
else:
# column/expression that polymorphic_on derives from
# is present in our mapped table
@@ -1518,16 +1583,15 @@ class Mapper(InspectionAttr):
# polymorphic_union.
# we'll make a separate ColumnProperty for it.
instrument = True
- key = getattr(col, 'key', None)
+ key = getattr(col, "key", None)
if key:
if self._should_exclude(col.key, col.key, False, col):
raise sa_exc.InvalidRequestError(
"Cannot exclude or override the "
- "discriminator column %r" %
- col.key)
+ "discriminator column %r" % col.key
+ )
else:
- self.polymorphic_on = col = \
- col.label("_sa_polymorphic_on")
+ self.polymorphic_on = col = col.label("_sa_polymorphic_on")
key = col.key
prop = properties.ColumnProperty(col, _instrument=instrument)
@@ -1551,43 +1615,51 @@ class Mapper(InspectionAttr):
if self.mapped_table is mapper.mapped_table:
self.polymorphic_on = mapper.polymorphic_on
else:
- self.polymorphic_on = \
- self.mapped_table.corresponding_column(
- mapper.polymorphic_on)
+ self.polymorphic_on = self.mapped_table.corresponding_column(
+ mapper.polymorphic_on
+ )
# we can use the parent mapper's _set_polymorphic_identity
# directly; it ensures the polymorphic_identity of the
# instance's mapper is used so is portable to subclasses.
if self.polymorphic_on is not None:
- self._set_polymorphic_identity = \
+ self._set_polymorphic_identity = (
mapper._set_polymorphic_identity
- self._validate_polymorphic_identity = \
+ )
+ self._validate_polymorphic_identity = (
mapper._validate_polymorphic_identity
+ )
else:
self._set_polymorphic_identity = None
return
if setter:
+
def _set_polymorphic_identity(state):
dict_ = state.dict
state.get_impl(polymorphic_key).set(
- state, dict_,
+ state,
+ dict_,
state.manager.mapper.polymorphic_identity,
- None)
+ None,
+ )
def _validate_polymorphic_identity(mapper, state, dict_):
- if polymorphic_key in dict_ and \
- dict_[polymorphic_key] not in \
- mapper._acceptable_polymorphic_identities:
+ if (
+ polymorphic_key in dict_
+ and dict_[polymorphic_key]
+ not in mapper._acceptable_polymorphic_identities
+ ):
util.warn_limited(
"Flushing object %s with "
"incompatible polymorphic identity %r; the "
"object may not refresh and/or load correctly",
- (state_str(state), dict_[polymorphic_key])
+ (state_str(state), dict_[polymorphic_key]),
)
self._set_polymorphic_identity = _set_polymorphic_identity
- self._validate_polymorphic_identity = \
+ self._validate_polymorphic_identity = (
_validate_polymorphic_identity
+ )
else:
self._set_polymorphic_identity = None
@@ -1628,16 +1700,20 @@ class Mapper(InspectionAttr):
# mapper and we don't map this. don't trip user-defined
# descriptors that might have side effects when invoked.
implementing_attribute = self.class_manager._get_class_attr_mro(
- key, prop)
- if implementing_attribute is prop or (isinstance(
- implementing_attribute,
- attributes.InstrumentedAttribute) and
- implementing_attribute._parententity is prop.parent
+ key, prop
+ )
+ if implementing_attribute is prop or (
+ isinstance(
+ implementing_attribute, attributes.InstrumentedAttribute
+ )
+ and implementing_attribute._parententity is prop.parent
):
self._configure_property(
key,
properties.ConcreteInheritedProperty(),
- init=init, setparent=True)
+ init=init,
+ setparent=True,
+ )
def _configure_property(self, key, prop, init=True, setparent=True):
self._log("_configure_property(%s, %s)", key, prop.__class__.__name__)
@@ -1659,7 +1735,8 @@ class Mapper(InspectionAttr):
for m2 in path:
m2.mapped_table._reset_exported()
col = self.mapped_table.corresponding_column(
- prop.columns[0])
+ prop.columns[0]
+ )
break
path.append(m)
@@ -1670,26 +1747,30 @@ class Mapper(InspectionAttr):
# column is coming in after _readonly_props was
# initialized; check for 'readonly'
- if hasattr(self, '_readonly_props') and \
- (not hasattr(col, 'table') or
- col.table not in self._cols_by_table):
+ if hasattr(self, "_readonly_props") and (
+ not hasattr(col, "table")
+ or col.table not in self._cols_by_table
+ ):
self._readonly_props.add(prop)
else:
# if column is coming in after _cols_by_table was
# initialized, ensure the col is in the right set
- if hasattr(self, '_cols_by_table') and \
- col.table in self._cols_by_table and \
- col not in self._cols_by_table[col.table]:
+ if (
+ hasattr(self, "_cols_by_table")
+ and col.table in self._cols_by_table
+ and col not in self._cols_by_table[col.table]
+ ):
self._cols_by_table[col.table].add(col)
# if this properties.ColumnProperty represents the "polymorphic
# discriminator" column, mark it. We'll need this when rendering
# columns in SELECT statements.
- if not hasattr(prop, '_is_polymorphic_discriminator'):
- prop._is_polymorphic_discriminator = \
- (col is self.polymorphic_on or
- prop.columns[0] is self.polymorphic_on)
+ if not hasattr(prop, "_is_polymorphic_discriminator"):
+ prop._is_polymorphic_discriminator = (
+ col is self.polymorphic_on
+ or prop.columns[0] is self.polymorphic_on
+ )
self.columns[key] = col
for col in prop.columns + prop._orig_columns:
@@ -1701,8 +1782,9 @@ class Mapper(InspectionAttr):
if setparent:
prop.set_parent(self, init)
- if key in self._props and \
- getattr(self._props[key], '_mapped_by_synonym', False):
+ if key in self._props and getattr(
+ self._props[key], "_mapped_by_synonym", False
+ ):
syn = self._props[key]._mapped_by_synonym
raise sa_exc.ArgumentError(
"Can't call map_column=True for synonym %r=%r, "
@@ -1710,20 +1792,22 @@ class Mapper(InspectionAttr):
"%r for column %r" % (syn, key, key, syn)
)
- if key in self._props and \
- not isinstance(prop, properties.ColumnProperty) and \
- not isinstance(
- self._props[key],
- (
- properties.ColumnProperty,
- properties.ConcreteInheritedProperty)
- ):
- util.warn("Property %s on %s being replaced with new "
- "property %s; the old property will be discarded" % (
- self._props[key],
- self,
- prop,
- ))
+ if (
+ key in self._props
+ and not isinstance(prop, properties.ColumnProperty)
+ and not isinstance(
+ self._props[key],
+ (
+ properties.ColumnProperty,
+ properties.ConcreteInheritedProperty,
+ ),
+ )
+ ):
+ util.warn(
+ "Property %s on %s being replaced with new "
+ "property %s; the old property will be discarded"
+ % (self._props[key], self, prop)
+ )
oldprop = self._props[key]
self._path_registry.pop(oldprop, None)
@@ -1753,23 +1837,29 @@ class Mapper(InspectionAttr):
if not expression._is_column(column):
raise sa_exc.ArgumentError(
"%s=%r is not an instance of MapperProperty or Column"
- % (key, prop))
+ % (key, prop)
+ )
prop = self._props.get(key, None)
if isinstance(prop, properties.ColumnProperty):
if (
- not self._inherits_equated_pairs or
- (prop.columns[0], column) not in self._inherits_equated_pairs
- ) and \
- not prop.columns[0].shares_lineage(column) and \
- prop.columns[0] is not self.version_id_col and \
- column is not self.version_id_col:
+ (
+ not self._inherits_equated_pairs
+ or (prop.columns[0], column)
+ not in self._inherits_equated_pairs
+ )
+ and not prop.columns[0].shares_lineage(column)
+ and prop.columns[0] is not self.version_id_col
+ and column is not self.version_id_col
+ ):
warn_only = prop.parent is not self
- msg = ("Implicitly combining column %s with column "
- "%s under attribute '%s'. Please configure one "
- "or more attributes for these same-named columns "
- "explicitly." % (prop.columns[-1], column, key))
+ msg = (
+ "Implicitly combining column %s with column "
+ "%s under attribute '%s'. Please configure one "
+ "or more attributes for these same-named columns "
+ "explicitly." % (prop.columns[-1], column, key)
+ )
if warn_only:
util.warn(msg)
else:
@@ -1779,11 +1869,14 @@ class Mapper(InspectionAttr):
# mapper. make a copy and append our column to it
prop = prop.copy()
prop.columns.insert(0, column)
- self._log("inserting column to existing list "
- "in properties.ColumnProperty %s" % (key))
+ self._log(
+ "inserting column to existing list "
+ "in properties.ColumnProperty %s" % (key)
+ )
return prop
- elif prop is None or isinstance(prop,
- properties.ConcreteInheritedProperty):
+ elif prop is None or isinstance(
+ prop, properties.ConcreteInheritedProperty
+ ):
mapped_column = []
for c in columns:
mc = self.mapped_table.corresponding_column(c)
@@ -1802,7 +1895,8 @@ class Mapper(InspectionAttr):
"column '%s' is not represented in the mapper's "
"table. Use the `column_property()` function to "
"force this column to be mapped as a read-only "
- "attribute." % (key, self, c))
+ "attribute." % (key, self, c)
+ )
mapped_column.append(mc)
return properties.ColumnProperty(*mapped_column)
else:
@@ -1815,8 +1909,8 @@ class Mapper(InspectionAttr):
"(including its availability as a foreign key), "
"use the 'include_properties' or 'exclude_properties' "
"mapper arguments to control specifically which table "
- "columns get mapped." %
- (key, self, column.key, prop))
+ "columns get mapped." % (key, self, column.key, prop)
+ )
def _post_configure_properties(self):
"""Call the ``init()`` method on all ``MapperProperties``
@@ -1867,34 +1961,35 @@ class Mapper(InspectionAttr):
@property
def _log_desc(self):
- return "(" + self.class_.__name__ + \
- "|" + \
- (self.local_table is not None and
- self.local_table.description or
- str(self.local_table)) +\
- (self.non_primary and
- "|non-primary" or "") + ")"
+ return (
+ "("
+ + self.class_.__name__
+ + "|"
+ + (
+ self.local_table is not None
+ and self.local_table.description
+ or str(self.local_table)
+ )
+ + (self.non_primary and "|non-primary" or "")
+ + ")"
+ )
def _log(self, msg, *args):
- self.logger.info(
- "%s " + msg, *((self._log_desc,) + args)
- )
+ self.logger.info("%s " + msg, *((self._log_desc,) + args))
def _log_debug(self, msg, *args):
- self.logger.debug(
- "%s " + msg, *((self._log_desc,) + args)
- )
+ self.logger.debug("%s " + msg, *((self._log_desc,) + args))
def __repr__(self):
- return '<Mapper at 0x%x; %s>' % (
- id(self), self.class_.__name__)
+ return "<Mapper at 0x%x; %s>" % (id(self), self.class_.__name__)
def __str__(self):
return "Mapper|%s|%s%s" % (
self.class_.__name__,
- self.local_table is not None and
- self.local_table.description or None,
- self.non_primary and "|non-primary" or ""
+ self.local_table is not None
+ and self.local_table.description
+ or None,
+ self.non_primary and "|non-primary" or "",
)
def _is_orphan(self, state):
@@ -1904,7 +1999,8 @@ class Mapper(InspectionAttr):
orphan_possible = True
has_parent = attributes.manager_of_class(cls).has_parent(
- state, key, optimistic=state.has_identity)
+ state, key, optimistic=state.has_identity
+ )
if self.legacy_is_orphan and has_parent:
return False
@@ -1930,7 +2026,8 @@ class Mapper(InspectionAttr):
return self._props[key]
except KeyError:
raise sa_exc.InvalidRequestError(
- "Mapper '%s' has no property '%s'" % (self, key))
+ "Mapper '%s' has no property '%s'" % (self, key)
+ )
def get_property_by_column(self, column):
"""Given a :class:`.Column` object, return the
@@ -1953,7 +2050,7 @@ class Mapper(InspectionAttr):
selectable, if present. This helps some more legacy-ish mappings.
"""
- if spec == '*':
+ if spec == "*":
mappers = list(self.self_and_descendants)
elif spec:
mappers = set()
@@ -1961,8 +2058,8 @@ class Mapper(InspectionAttr):
m = _class_to_mapper(m)
if not m.isa(self):
raise sa_exc.InvalidRequestError(
- "%r does not inherit from %r" %
- (m, self))
+ "%r does not inherit from %r" % (m, self)
+ )
if selectable is None:
mappers.update(m.iterate_to_root())
@@ -1973,8 +2070,9 @@ class Mapper(InspectionAttr):
mappers = []
if selectable is not None:
- tables = set(sql_util.find_tables(selectable,
- include_aliases=True))
+ tables = set(
+ sql_util.find_tables(selectable, include_aliases=True)
+ )
mappers = [m for m in mappers if m.local_table in tables]
return mappers
@@ -1991,25 +2089,26 @@ class Mapper(InspectionAttr):
if m.concrete:
raise sa_exc.InvalidRequestError(
"'with_polymorphic()' requires 'selectable' argument "
- "when concrete-inheriting mappers are used.")
+ "when concrete-inheriting mappers are used."
+ )
elif not m.single:
if innerjoin:
- from_obj = from_obj.join(m.local_table,
- m.inherit_condition)
+ from_obj = from_obj.join(
+ m.local_table, m.inherit_condition
+ )
else:
- from_obj = from_obj.outerjoin(m.local_table,
- m.inherit_condition)
+ from_obj = from_obj.outerjoin(
+ m.local_table, m.inherit_condition
+ )
return from_obj
@_memoized_configured_property
def _single_table_criterion(self):
- if self.single and \
- self.inherits and \
- self.polymorphic_on is not None:
+ if self.single and self.inherits and self.polymorphic_on is not None:
return self.polymorphic_on.in_(
- m.polymorphic_identity
- for m in self.self_and_descendants)
+ m.polymorphic_identity for m in self.self_and_descendants
+ )
else:
return None
@@ -2031,8 +2130,8 @@ class Mapper(InspectionAttr):
return selectable
else:
return self._selectable_from_mappers(
- self._mappers_from_spec(spec, selectable),
- False)
+ self._mappers_from_spec(spec, selectable), False
+ )
with_polymorphic_mappers = _with_polymorphic_mappers
"""The list of :class:`.Mapper` objects included in the
@@ -2046,9 +2145,8 @@ class Mapper(InspectionAttr):
(
table,
frozenset(
- col for col in columns
- if col.type.should_evaluate_none
- )
+ col for col in columns if col.type.should_evaluate_none
+ ),
)
for table, columns in self._cols_by_table.items()
)
@@ -2059,10 +2157,13 @@ class Mapper(InspectionAttr):
(
table,
frozenset(
- col.key for col in columns
- if not col.primary_key and
- not col.server_default and not col.default
- and not col.type.should_evaluate_none)
+ col.key
+ for col in columns
+ if not col.primary_key
+ and not col.server_default
+ and not col.default
+ and not col.type.should_evaluate_none
+ ),
)
for table, columns in self._cols_by_table.items()
)
@@ -2073,9 +2174,8 @@ class Mapper(InspectionAttr):
(
table,
dict(
- (self._columntoproperty[col].key, col)
- for col in columns
- )
+ (self._columntoproperty[col].key, col) for col in columns
+ ),
)
for table, columns in self._cols_by_table.items()
)
@@ -2083,10 +2183,7 @@ class Mapper(InspectionAttr):
@_memoized_configured_property
def _pk_keys_by_table(self):
return dict(
- (
- table,
- frozenset([col.key for col in pks])
- )
+ (table, frozenset([col.key for col in pks]))
for table, pks in self._pks_by_table.items()
)
@@ -2095,7 +2192,7 @@ class Mapper(InspectionAttr):
return dict(
(
table,
- frozenset([self._columntoproperty[col].key for col in pks])
+ frozenset([self._columntoproperty[col].key for col in pks]),
)
for table, pks in self._pks_by_table.items()
)
@@ -2105,9 +2202,13 @@ class Mapper(InspectionAttr):
return dict(
(
table,
- frozenset([
- col.key for col in columns
- if col.server_default is not None])
+ frozenset(
+ [
+ col.key
+ for col in columns
+ if col.server_default is not None
+ ]
+ ),
)
for table, columns in self._cols_by_table.items()
)
@@ -2119,11 +2220,9 @@ class Mapper(InspectionAttr):
for table, columns in self._cols_by_table.items():
for col in columns:
if (
- (
- col.server_default is not None or
- col.server_onupdate is not None
- ) and col in self._columntoproperty
- ):
+ col.server_default is not None
+ or col.server_onupdate is not None
+ ) and col in self._columntoproperty:
result.add(self._columntoproperty[col].key)
return result
@@ -2133,9 +2232,13 @@ class Mapper(InspectionAttr):
return dict(
(
table,
- frozenset([
- col.key for col in columns
- if col.server_onupdate is not None])
+ frozenset(
+ [
+ col.key
+ for col in columns
+ if col.server_onupdate is not None
+ ]
+ ),
)
for table, columns in self._cols_by_table.items()
)
@@ -2152,8 +2255,9 @@ class Mapper(InspectionAttr):
"""
return self._with_polymorphic_selectable
- def _with_polymorphic_args(self, spec=None, selectable=False,
- innerjoin=False):
+ def _with_polymorphic_args(
+ self, spec=None, selectable=False, innerjoin=False
+ ):
if self.with_polymorphic:
if not spec:
spec = self.with_polymorphic[0]
@@ -2165,13 +2269,15 @@ class Mapper(InspectionAttr):
if selectable is not None:
return mappers, selectable
else:
- return mappers, self._selectable_from_mappers(mappers,
- innerjoin)
+ return mappers, self._selectable_from_mappers(mappers, innerjoin)
@_memoized_configured_property
def _polymorphic_properties(self):
- return list(self._iterate_polymorphic_properties(
- self._with_polymorphic_mappers))
+ return list(
+ self._iterate_polymorphic_properties(
+ self._with_polymorphic_mappers
+ )
+ )
def _iterate_polymorphic_properties(self, mappers=None):
"""Return an iterator of MapperProperty objects which will render into
@@ -2187,14 +2293,17 @@ class Mapper(InspectionAttr):
# from other mappers, as these are sometimes dependent on that
# mapper's polymorphic selectable (which we don't want rendered)
for c in util.unique_list(
- chain(*[
- list(mapper.iterate_properties) for mapper in
- [self] + mappers
- ])
+ chain(
+ *[
+ list(mapper.iterate_properties)
+ for mapper in [self] + mappers
+ ]
+ )
):
- if getattr(c, '_is_polymorphic_discriminator', False) and \
- (self.polymorphic_on is None or
- c.columns[0] is not self.polymorphic_on):
+ if getattr(c, "_is_polymorphic_discriminator", False) and (
+ self.polymorphic_on is None
+ or c.columns[0] is not self.polymorphic_on
+ ):
continue
yield c
@@ -2282,7 +2391,8 @@ class Mapper(InspectionAttr):
"""
return util.ImmutableProperties(
- dict(self.class_manager._all_sqla_attributes()))
+ dict(self.class_manager._all_sqla_attributes())
+ )
@_memoized_configured_property
def synonyms(self):
@@ -2351,10 +2461,11 @@ class Mapper(InspectionAttr):
def _filter_properties(self, type_):
if Mapper._new_mappers:
configure_mappers()
- return util.ImmutableProperties(util.OrderedDict(
- (k, v) for k, v in self._props.items()
- if isinstance(v, type_)
- ))
+ return util.ImmutableProperties(
+ util.OrderedDict(
+ (k, v) for k, v in self._props.items() if isinstance(v, type_)
+ )
+ )
@_memoized_configured_property
def _get_clause(self):
@@ -2363,10 +2474,14 @@ class Mapper(InspectionAttr):
by primary key.
"""
- params = [(primary_key, sql.bindparam(None, type_=primary_key.type))
- for primary_key in self.primary_key]
- return sql.and_(*[k == v for (k, v) in params]), \
- util.column_dict(params)
+ params = [
+ (primary_key, sql.bindparam(None, type_=primary_key.type))
+ for primary_key in self.primary_key
+ ]
+ return (
+ sql.and_(*[k == v for (k, v) in params]),
+ util.column_dict(params),
+ )
@_memoized_configured_property
def _equivalent_columns(self):
@@ -2401,18 +2516,24 @@ class Mapper(InspectionAttr):
result[binary.right].add(binary.left)
else:
result[binary.right] = util.column_set((binary.left,))
+
for mapper in self.base_mapper.self_and_descendants:
if mapper.inherit_condition is not None:
visitors.traverse(
- mapper.inherit_condition, {},
- {'binary': visit_binary})
+ mapper.inherit_condition, {}, {"binary": visit_binary}
+ )
return result
def _is_userland_descriptor(self, obj):
- if isinstance(obj, (_MappedAttribute,
- instrumentation.ClassManager,
- expression.ColumnElement)):
+ if isinstance(
+ obj,
+ (
+ _MappedAttribute,
+ instrumentation.ClassManager,
+ expression.ColumnElement,
+ ),
+ ):
return False
else:
return True
@@ -2429,26 +2550,29 @@ class Mapper(InspectionAttr):
# check for class-bound attributes and/or descriptors,
# either local or from an inherited class
if local:
- if self.class_.__dict__.get(assigned_name, None) is not None \
- and self._is_userland_descriptor(
- self.class_.__dict__[assigned_name]):
+ if self.class_.__dict__.get(
+ assigned_name, None
+ ) is not None and self._is_userland_descriptor(
+ self.class_.__dict__[assigned_name]
+ ):
return True
else:
attr = self.class_manager._get_class_attr_mro(assigned_name, None)
if attr is not None and self._is_userland_descriptor(attr):
return True
- if self.include_properties is not None and \
- name not in self.include_properties and \
- (column is None or column not in self.include_properties):
+ if (
+ self.include_properties is not None
+ and name not in self.include_properties
+ and (column is None or column not in self.include_properties)
+ ):
self._log("not including property %s" % (name))
return True
- if self.exclude_properties is not None and \
- (
- name in self.exclude_properties or
- (column is not None and column in self.exclude_properties)
- ):
+ if self.exclude_properties is not None and (
+ name in self.exclude_properties
+ or (column is not None and column in self.exclude_properties)
+ ):
self._log("excluding property %s" % (name))
return True
@@ -2545,8 +2669,11 @@ class Mapper(InspectionAttr):
if adapter:
pk_cols = [adapter.columns[c] for c in pk_cols]
- return self._identity_class, \
- tuple(row[column] for column in pk_cols), identity_token
+ return (
+ self._identity_class,
+ tuple(row[column] for column in pk_cols),
+ identity_token,
+ )
def identity_key_from_primary_key(self, primary_key, identity_token=None):
"""Return an identity-map key for use in storing/retrieving an
@@ -2574,14 +2701,20 @@ class Mapper(InspectionAttr):
return self._identity_key_from_state(state, attributes.PASSIVE_OFF)
def _identity_key_from_state(
- self, state, passive=attributes.PASSIVE_RETURN_NEVER_SET):
+ self, state, passive=attributes.PASSIVE_RETURN_NEVER_SET
+ ):
dict_ = state.dict
manager = state.manager
- return self._identity_class, tuple([
- manager[prop.key].
- impl.get(state, dict_, passive)
- for prop in self._identity_key_props
- ]), state.identity_token
+ return (
+ self._identity_class,
+ tuple(
+ [
+ manager[prop.key].impl.get(state, dict_, passive)
+ for prop in self._identity_key_props
+ ]
+ ),
+ state.identity_token,
+ )
def primary_key_from_instance(self, instance):
"""Return the list of primary key values for the given
@@ -2595,7 +2728,8 @@ class Mapper(InspectionAttr):
"""
state = attributes.instance_state(instance)
identity_key = self._identity_key_from_state(
- state, attributes.PASSIVE_OFF)
+ state, attributes.PASSIVE_OFF
+ )
return identity_key[1]
@_memoized_configured_property
@@ -2621,8 +2755,8 @@ class Mapper(InspectionAttr):
return {prop.key for prop in self._all_pk_props}
def _get_state_attr_by_column(
- self, state, dict_, column,
- passive=attributes.PASSIVE_RETURN_NEVER_SET):
+ self, state, dict_, column, passive=attributes.PASSIVE_RETURN_NEVER_SET
+ ):
prop = self._columntoproperty[column]
return state.manager[prop.key].impl.get(state, dict_, passive=passive)
@@ -2638,15 +2772,17 @@ class Mapper(InspectionAttr):
state = attributes.instance_state(obj)
dict_ = attributes.instance_dict(obj)
return self._get_committed_state_attr_by_column(
- state, dict_, column, passive=attributes.PASSIVE_OFF)
+ state, dict_, column, passive=attributes.PASSIVE_OFF
+ )
def _get_committed_state_attr_by_column(
- self, state, dict_, column,
- passive=attributes.PASSIVE_RETURN_NEVER_SET):
+ self, state, dict_, column, passive=attributes.PASSIVE_RETURN_NEVER_SET
+ ):
prop = self._columntoproperty[column]
- return state.manager[prop.key].impl.\
- get_committed_value(state, dict_, passive=passive)
+ return state.manager[prop.key].impl.get_committed_value(
+ state, dict_, passive=passive
+ )
def _optimized_get_statement(self, state, attribute_names):
"""assemble a WHERE clause which retrieves a given state by primary
@@ -2660,11 +2796,15 @@ class Mapper(InspectionAttr):
"""
props = self._props
- tables = set(chain(
- *[sql_util.find_tables(c, check_columns=True)
- for key in attribute_names
- for c in props[key].columns]
- ))
+ tables = set(
+ chain(
+ *[
+ sql_util.find_tables(c, check_columns=True)
+ for key in attribute_names
+ for c in props[key].columns
+ ]
+ )
+ )
if self.base_mapper.local_table in tables:
return None
@@ -2680,22 +2820,28 @@ class Mapper(InspectionAttr):
if leftcol.table not in tables:
leftval = self._get_committed_state_attr_by_column(
- state, state.dict,
+ state,
+ state.dict,
leftcol,
- passive=attributes.PASSIVE_NO_INITIALIZE)
+ passive=attributes.PASSIVE_NO_INITIALIZE,
+ )
if leftval in orm_util._none_set:
raise ColumnsNotAvailable()
- binary.left = sql.bindparam(None, leftval,
- type_=binary.right.type)
+ binary.left = sql.bindparam(
+ None, leftval, type_=binary.right.type
+ )
elif rightcol.table not in tables:
rightval = self._get_committed_state_attr_by_column(
- state, state.dict,
+ state,
+ state.dict,
rightcol,
- passive=attributes.PASSIVE_NO_INITIALIZE)
+ passive=attributes.PASSIVE_NO_INITIALIZE,
+ )
if rightval in orm_util._none_set:
raise ColumnsNotAvailable()
- binary.right = sql.bindparam(None, rightval,
- type_=binary.right.type)
+ binary.right = sql.bindparam(
+ None, rightval, type_=binary.right.type
+ )
allconds = []
@@ -2704,15 +2850,17 @@ class Mapper(InspectionAttr):
for mapper in reversed(list(self.iterate_to_root())):
if mapper.local_table in tables:
start = True
- elif not isinstance(mapper.local_table,
- expression.TableClause):
+ elif not isinstance(
+ mapper.local_table, expression.TableClause
+ ):
return None
if start and not mapper.single:
- allconds.append(visitors.cloned_traverse(
- mapper.inherit_condition,
- {},
- {'binary': visit_binary}
- )
+ allconds.append(
+ visitors.cloned_traverse(
+ mapper.inherit_condition,
+ {},
+ {"binary": visit_binary},
+ )
)
except ColumnsNotAvailable:
return None
@@ -2730,8 +2878,7 @@ class Mapper(InspectionAttr):
for m in self.iterate_to_root():
yield m
- if m is not prev and prev not in \
- m._with_polymorphic_mappers:
+ if m is not prev and prev not in m._with_polymorphic_mappers:
break
prev = m
@@ -2743,7 +2890,7 @@ class Mapper(InspectionAttr):
# common case, takes place for all polymorphic loads
mapper = polymorphic_from
for m in self._iterate_to_target_viawpoly(mapper):
- if m.polymorphic_load == 'selectin':
+ if m.polymorphic_load == "selectin":
return m
else:
# uncommon case, selectin load options were used
@@ -2752,15 +2899,17 @@ class Mapper(InspectionAttr):
for entity in enabled_via_opt.union([polymorphic_from]):
mapper = entity.mapper
for m in self._iterate_to_target_viawpoly(mapper):
- if m.polymorphic_load == 'selectin' or \
- m in enabled_via_opt_mappers:
+ if (
+ m.polymorphic_load == "selectin"
+ or m in enabled_via_opt_mappers
+ ):
return enabled_via_opt_mappers.get(m, m)
return None
@util.dependencies(
- "sqlalchemy.ext.baked",
- "sqlalchemy.orm.strategy_options")
+ "sqlalchemy.ext.baked", "sqlalchemy.orm.strategy_options"
+ )
def _subclass_load_via_in(self, baked, strategy_options, entity):
"""Assemble a BakedQuery that can load the columns local to
this subclass as a SELECT with IN.
@@ -2768,10 +2917,8 @@ class Mapper(InspectionAttr):
"""
assert self.inherits
- polymorphic_prop = self._columntoproperty[
- self.polymorphic_on]
- keep_props = set(
- [polymorphic_prop] + self._identity_key_props)
+ polymorphic_prop = self._columntoproperty[self.polymorphic_on]
+ keep_props = set([polymorphic_prop] + self._identity_key_props)
disable_opt = strategy_options.Load(entity)
enable_opt = strategy_options.Load(entity)
@@ -2781,16 +2928,14 @@ class Mapper(InspectionAttr):
# "enable" options, to turn on the properties that we want to
# load by default (subject to options from the query)
enable_opt.set_generic_strategy(
- (prop.key, ),
- dict(prop.strategy_key)
+ (prop.key,), dict(prop.strategy_key)
)
else:
# "disable" options, to turn off the properties from the
# superclass that we *don't* want to load, applied after
# the options from the query to override them
disable_opt.set_generic_strategy(
- (prop.key, ),
- {"do_nothing": True}
+ (prop.key,), {"do_nothing": True}
)
if len(self.primary_key) > 1:
@@ -2802,22 +2947,21 @@ class Mapper(InspectionAttr):
assert entity.mapper is self
q = baked.BakedQuery(
self._compiled_cache,
- lambda session: session.query(entity).
- select_entity_from(entity.selectable)._adapt_all_clauses(),
- (self, )
+ lambda session: session.query(entity)
+ .select_entity_from(entity.selectable)
+ ._adapt_all_clauses(),
+ (self,),
)
q.spoil()
else:
q = baked.BakedQuery(
self._compiled_cache,
lambda session: session.query(self),
- (self, )
+ (self,),
)
q += lambda q: q.filter(
- in_expr.in_(
- sql.bindparam('primary_keys', expanding=True)
- )
+ in_expr.in_(sql.bindparam("primary_keys", expanding=True))
).order_by(*self.primary_key)
return q, enable_opt, disable_opt
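
_subclass_load_via_in filters its baked query with in_(bindparam("primary_keys", expanding=True)), so the list of key values is spliced in at execution time rather than baked into the cached statement. The same device in plain Core, as a self-contained sketch:

    from sqlalchemy import (
        Column,
        Integer,
        MetaData,
        Table,
        bindparam,
        create_engine,
        select,
    )

    metadata = MetaData()
    t = Table("t", metadata, Column("id", Integer, primary_key=True))

    engine = create_engine("sqlite://")
    metadata.create_all(engine)
    engine.execute(t.insert(), [{"id": 1}, {"id": 2}, {"id": 3}])

    # an expanding bindparam defers the IN list to execution time, the
    # same device bound as "primary_keys" in the baked query above
    stmt = select([t]).where(
        t.c.id.in_(bindparam("primary_keys", expanding=True))
    )
    print(engine.execute(stmt, {"primary_keys": [1, 3]}).fetchall())
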
@@ -2856,8 +3000,9 @@ class Mapper(InspectionAttr):
assert state.mapper.isa(self)
- visitables = deque([(deque(state.mapper._props.values()), prp,
- state, state.dict)])
+ visitables = deque(
+ [(deque(state.mapper._props.values()), prp, state, state.dict)]
+ )
while visitables:
iterator, item_type, parent_state, parent_dict = visitables[-1]
@@ -2869,21 +3014,28 @@ class Mapper(InspectionAttr):
prop = iterator.popleft()
if type_ not in prop.cascade:
continue
- queue = deque(prop.cascade_iterator(
- type_, parent_state, parent_dict,
- visited_states, halt_on))
+ queue = deque(
+ prop.cascade_iterator(
+ type_,
+ parent_state,
+ parent_dict,
+ visited_states,
+ halt_on,
+ )
+ )
if queue:
visitables.append((queue, mpp, None, None))
elif item_type is mpp:
- instance, instance_mapper, corresponding_state, \
- corresponding_dict = iterator.popleft()
- yield instance, instance_mapper, \
- corresponding_state, corresponding_dict
+ instance, instance_mapper, corresponding_state, corresponding_dict = (
+ iterator.popleft()
+ )
+ yield instance, instance_mapper, corresponding_state, corresponding_dict
visitables.append(
(
deque(instance_mapper._props.values()),
- prp, corresponding_state,
- corresponding_dict
+ prp,
+ corresponding_state,
+ corresponding_dict,
)
)
@@ -2903,10 +3055,9 @@ class Mapper(InspectionAttr):
for table, mapper in table_to_mapper.items():
super_ = mapper.inherits
if super_:
- extra_dependencies.extend([
- (super_table, table)
- for super_table in super_.tables
- ])
+ extra_dependencies.extend(
+ [(super_table, table) for super_table in super_.tables]
+ )
def skip(fk):
# attempt to skip dependencies that are not
@@ -2916,22 +3067,27 @@ class Mapper(InspectionAttr):
# not what we mean to sort on here.
parent = table_to_mapper.get(fk.parent.table)
dep = table_to_mapper.get(fk.column.table)
- if parent is not None and \
- dep is not None and \
- dep is not parent and \
- dep.inherit_condition is not None:
+ if (
+ parent is not None
+ and dep is not None
+ and dep is not parent
+ and dep.inherit_condition is not None
+ ):
cols = set(sql_util._find_columns(dep.inherit_condition))
if parent.inherit_condition is not None:
- cols = cols.union(sql_util._find_columns(
- parent.inherit_condition))
+ cols = cols.union(
+ sql_util._find_columns(parent.inherit_condition)
+ )
return fk.parent not in cols and fk.column not in cols
else:
return fk.parent not in cols
return False
- sorted_ = sql_util.sort_tables(table_to_mapper,
- skip_fn=skip,
- extra_dependencies=extra_dependencies)
+ sorted_ = sql_util.sort_tables(
+ table_to_mapper,
+ skip_fn=skip,
+ extra_dependencies=extra_dependencies,
+ )
ret = util.OrderedDict()
for t in sorted_:
@@ -2955,12 +3111,12 @@ class Mapper(InspectionAttr):
for table in self._sorted_tables:
cols = set(table.c)
for m in self.iterate_to_root():
- if m._inherits_equated_pairs and \
- cols.intersection(
- util.reduce(set.union,
- [l.proxy_set for l, r in
- m._inherits_equated_pairs])
- ):
+ if m._inherits_equated_pairs and cols.intersection(
+ util.reduce(
+ set.union,
+ [l.proxy_set for l, r in m._inherits_equated_pairs],
+ )
+ ):
result[table].append((m, m._inherits_equated_pairs))
return result
@@ -3034,13 +3190,14 @@ def configure_mappers():
if run_configure is EXT_SKIP:
continue
- if getattr(mapper, '_configure_failed', False):
+ if getattr(mapper, "_configure_failed", False):
e = sa_exc.InvalidRequestError(
"One or more mappers failed to initialize - "
"can't proceed with initialization of other "
"mappers. Triggering mapper: '%s'. "
"Original exception was: %s"
- % (mapper, mapper._configure_failed))
+ % (mapper, mapper._configure_failed)
+ )
e._configure_failed = mapper._configure_failed
raise e
@@ -3049,10 +3206,11 @@ def configure_mappers():
mapper._post_configure_properties()
mapper._expire_memoizations()
mapper.dispatch.mapper_configured(
- mapper, mapper.class_)
+ mapper, mapper.class_
+ )
except Exception:
exc = sys.exc_info()[1]
- if not hasattr(exc, '_configure_failed'):
+ if not hasattr(exc, "_configure_failed"):
mapper._configure_failed = exc
raise
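
configure_mappers() runs the deferred configure step for every known mapper, chaining the original failure onto later attempts as the error handling above shows. It normally fires lazily, but calling it directly surfaces mapping errors early; a minimal sketch:

    from sqlalchemy import Column, Integer, String
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import configure_mappers

    Base = declarative_base()

    class User(Base):
        __tablename__ = "user"
        id = Column(Integer, primary_key=True)
        name = Column(String(50))

    # normally triggered on first query or instantiation; invoking it
    # up front raises configuration failures (like the chained
    # "_configure_failed" error above) at startup instead
    configure_mappers()
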
@@ -3127,16 +3285,17 @@ def validates(*names, **kw):
:ref:`simple_validators` - usage examples for :func:`.validates`
"""
- include_removes = kw.pop('include_removes', False)
- include_backrefs = kw.pop('include_backrefs', True)
+ include_removes = kw.pop("include_removes", False)
+ include_backrefs = kw.pop("include_backrefs", True)
def wrap(fn):
fn.__sa_validators__ = names
fn.__sa_validation_opts__ = {
"include_removes": include_removes,
- "include_backrefs": include_backrefs
+ "include_backrefs": include_backrefs,
}
return fn
+
return wrap
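
The two kwargs popped above control the validator's scope: include_removes=True passes collection removals through with an extra boolean flag, and include_backrefs=False would skip events arriving via a backref. The documented usage, sketched with assumed models:

    from sqlalchemy import Column, ForeignKey, Integer, String
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import relationship, validates

    Base = declarative_base()

    class User(Base):
        __tablename__ = "user"
        id = Column(Integer, primary_key=True)
        email = Column(String(100))
        addresses = relationship("Address")

        @validates("email")
        def validate_email(self, key, value):
            assert "@" in value
            return value

        # include_removes=True adds the is_remove flag and routes
        # collection removals through the validator as well
        @validates("addresses", include_removes=True)
        def validate_address(self, key, address, is_remove):
            if is_remove:
                raise ValueError("removal not allowed")
            return address

    class Address(Base):
        __tablename__ = "address"
        id = Column(Integer, primary_key=True)
        user_id = Column(Integer, ForeignKey("user.id"))
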
@@ -3180,7 +3339,7 @@ def _event_on_init(state, args, kwargs):
class _ColumnMapping(dict):
"""Error reporting helper for mapper._columntoproperty."""
- __slots__ = 'mapper',
+ __slots__ = ("mapper",)
def __init__(self, mapper):
self.mapper = mapper
@@ -3190,8 +3349,10 @@ class _ColumnMapping(dict):
if prop:
raise orm_exc.UnmappedColumnError(
"Column '%s.%s' is not available, due to "
- "conflicting property '%s':%r" % (
- column.table.name, column.name, column.key, prop))
+ "conflicting property '%s':%r"
+ % (column.table.name, column.name, column.key, prop)
+ )
raise orm_exc.UnmappedColumnError(
- "No column %s is configured on mapper %s..." %
- (column, self.mapper))
+ "No column %s is configured on mapper %s..."
+ % (column, self.mapper)
+ )
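
The mapper.py hunks above also touch the error chaining in configure_mappers(). A short sketch of the public entry point; the failure handling is illustrative and relies on the private _configure_failed attribute assigned in the hunk above:

    from sqlalchemy import exc as sa_exc
    from sqlalchemy.orm import configure_mappers

    try:
        # Resolve all pending inter-mapper configuration (relationships,
        # backrefs, etc.) eagerly, instead of lazily at first use.
        configure_mappers()
    except sa_exc.InvalidRequestError as err:
        # The originating exception is chained on via the private
        # ``_configure_failed`` attribute, as seen in the hunk above.
        print("mapper configuration failed:",
              getattr(err, "_configure_failed", None))
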
diff --git a/lib/sqlalchemy/orm/path_registry.py b/lib/sqlalchemy/orm/path_registry.py
index bb4e2eda5..f33c209cc 100644
--- a/lib/sqlalchemy/orm/path_registry.py
+++ b/lib/sqlalchemy/orm/path_registry.py
@@ -56,8 +56,7 @@ class PathRegistry(object):
is_root = False
def __eq__(self, other):
- return other is not None and \
- self.path == other.path
+ return other is not None and self.path == other.path
def set(self, attributes, key, value):
log.debug("set '%s' on path '%s' to '%s'", key, self, value)
@@ -87,11 +86,8 @@ class PathRegistry(object):
yield path[i], path[i + 1]
def contains_mapper(self, mapper):
- for path_mapper in [
- self.path[i] for i in range(0, len(self.path), 2)
- ]:
- if path_mapper.is_mapper and \
- path_mapper.isa(mapper):
+ for path_mapper in [self.path[i] for i in range(0, len(self.path), 2)]:
+ if path_mapper.is_mapper and path_mapper.isa(mapper):
return True
else:
return False
@@ -100,40 +96,49 @@ class PathRegistry(object):
return (key, self.path) in attributes
def __reduce__(self):
- return _unreduce_path, (self.serialize(), )
+ return _unreduce_path, (self.serialize(),)
def serialize(self):
path = self.path
- return list(zip(
- [m.class_ for m in [path[i] for i in range(0, len(path), 2)]],
- [path[i].key for i in range(1, len(path), 2)] + [None]
- ))
+ return list(
+ zip(
+ [m.class_ for m in [path[i] for i in range(0, len(path), 2)]],
+ [path[i].key for i in range(1, len(path), 2)] + [None],
+ )
+ )
@classmethod
def deserialize(cls, path):
if path is None:
return None
- p = tuple(chain(*[(class_mapper(mcls),
- class_mapper(mcls).attrs[key]
- if key is not None else None)
- for mcls, key in path]))
+ p = tuple(
+ chain(
+ *[
+ (
+ class_mapper(mcls),
+ class_mapper(mcls).attrs[key]
+ if key is not None
+ else None,
+ )
+ for mcls, key in path
+ ]
+ )
+ )
if p and p[-1] is None:
p = p[0:-1]
return cls.coerce(p)
@classmethod
def per_mapper(cls, mapper):
- return EntityRegistry(
- cls.root, mapper
- )
+ return EntityRegistry(cls.root, mapper)
@classmethod
def coerce(cls, raw):
return util.reduce(lambda prev, next: prev[next], raw, cls.root)
def token(self, token):
- if token.endswith(':' + _WILDCARD_TOKEN):
+ if token.endswith(":" + _WILDCARD_TOKEN):
return TokenRegistry(self, token)
elif token.endswith(":" + _DEFAULT_TOKEN):
return TokenRegistry(self.root, token)
@@ -141,12 +146,10 @@ class PathRegistry(object):
raise exc.ArgumentError("invalid token: %s" % token)
def __add__(self, other):
- return util.reduce(
- lambda prev, next: prev[next],
- other.path, self)
+ return util.reduce(lambda prev, next: prev[next], other.path, self)
def __repr__(self):
- return "%s(%r)" % (self.__class__.__name__, self.path, )
+ return "%s(%r)" % (self.__class__.__name__, self.path)
class RootRegistry(PathRegistry):
@@ -154,6 +157,7 @@ class RootRegistry(PathRegistry):
paths are maintained per-root-mapper.
"""
+
path = ()
has_entity = False
is_aliased_class = False
@@ -162,6 +166,7 @@ class RootRegistry(PathRegistry):
def __getitem__(self, entity):
return entity._path_registry
+
PathRegistry.root = RootRegistry()
@@ -194,8 +199,10 @@ class PropRegistry(PathRegistry):
if not insp.is_aliased_class or insp._use_mapper_path:
parent = parent.parent[prop.parent]
elif insp.is_aliased_class and insp.with_polymorphic_mappers:
- if prop.parent is not insp.mapper and \
- prop.parent in insp.with_polymorphic_mappers:
+ if (
+ prop.parent is not insp.mapper
+ and prop.parent in insp.with_polymorphic_mappers
+ ):
subclass_entity = parent[-1]._entity_for_mapper(prop.parent)
parent = parent.parent[subclass_entity]
@@ -205,15 +212,13 @@ class PropRegistry(PathRegistry):
self._wildcard_path_loader_key = (
"loader",
- self.parent.path + self.prop._wildcard_token
+ self.parent.path + self.prop._wildcard_token,
)
self._default_path_loader_key = self.prop._default_path_loader_key
self._loader_key = ("loader", self.path)
def __str__(self):
- return " -> ".join(
- str(elem) for elem in self.path
- )
+ return " -> ".join(str(elem) for elem in self.path)
@util.memoized_property
def has_entity(self):
@@ -235,9 +240,7 @@ class PropRegistry(PathRegistry):
if isinstance(entity, (int, slice)):
return self.path[entity]
else:
- return EntityRegistry(
- self, entity
- )
+ return EntityRegistry(self, entity)
class EntityRegistry(PathRegistry, dict):
@@ -258,6 +261,7 @@ class EntityRegistry(PathRegistry, dict):
def __bool__(self):
return True
+
__nonzero__ = __bool__
def __getitem__(self, entity):
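
Several of the path_registry.py hunks center on the same reduce() idiom used by PathRegistry.coerce() and __add__(). A standalone illustration of that traversal pattern, using plain dicts instead of registry nodes:

    from functools import reduce

    # Each step indexes the result of the previous step, so a flat
    # sequence of keys walks an arbitrarily nested structure.
    tree = {"a": {"b": {"c": 42}}}
    path = ("a", "b", "c")

    leaf = reduce(lambda prev, nxt: prev[nxt], path, tree)
    assert leaf == 42
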
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py
index 7f9b7db0c..dc86a60e5 100644
--- a/lib/sqlalchemy/orm/persistence.py
+++ b/lib/sqlalchemy/orm/persistence.py
@@ -25,8 +25,13 @@ from . import loading
def _bulk_insert(
- mapper, mappings, session_transaction, isstates, return_defaults,
- render_nulls):
+ mapper,
+ mappings,
+ session_transaction,
+ isstates,
+ return_defaults,
+ render_nulls,
+):
base_mapper = mapper.base_mapper
cached_connections = _cached_connection_dict(base_mapper)
@@ -34,7 +39,8 @@ def _bulk_insert(
if session_transaction.session.connection_callable:
raise NotImplementedError(
"connection_callable / per-instance sharding "
- "not supported in bulk_insert()")
+ "not supported in bulk_insert()"
+ )
if isstates:
if return_defaults:
@@ -51,22 +57,33 @@ def _bulk_insert(
continue
records = (
- (None, state_dict, params, mapper,
- connection, value_params, has_all_pks, has_all_defaults)
- for
- state, state_dict, params, mp,
- conn, value_params, has_all_pks,
- has_all_defaults in _collect_insert_commands(table, (
- (None, mapping, mapper, connection)
- for mapping in mappings),
- bulk=True, return_defaults=return_defaults,
- render_nulls=render_nulls
+ (
+ None,
+ state_dict,
+ params,
+ mapper,
+ connection,
+ value_params,
+ has_all_pks,
+ has_all_defaults,
+ )
+            for (
+                state,
+                state_dict,
+                params,
+                mp,
+                conn,
+                value_params,
+                has_all_pks,
+                has_all_defaults,
+            ) in _collect_insert_commands(
+ table,
+ ((None, mapping, mapper, connection) for mapping in mappings),
+ bulk=True,
+ return_defaults=return_defaults,
+ render_nulls=render_nulls,
)
)
- _emit_insert_statements(base_mapper, None,
- cached_connections,
- super_mapper, table, records,
- bookkeeping=return_defaults)
+ _emit_insert_statements(
+ base_mapper,
+ None,
+ cached_connections,
+ super_mapper,
+ table,
+ records,
+ bookkeeping=return_defaults,
+ )
if return_defaults and isstates:
identity_cls = mapper._identity_class
@@ -74,12 +91,13 @@ def _bulk_insert(
for state, dict_ in states:
state.key = (
identity_cls,
- tuple([dict_[key] for key in identity_props])
+ tuple([dict_[key] for key in identity_props]),
)
-def _bulk_update(mapper, mappings, session_transaction,
- isstates, update_changed_only):
+def _bulk_update(
+ mapper, mappings, session_transaction, isstates, update_changed_only
+):
base_mapper = mapper.base_mapper
cached_connections = _cached_connection_dict(base_mapper)
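
_bulk_insert() and _bulk_update() are the internals behind the Session bulk-mapping methods. A minimal sketch of the calling side; the session and User model are assumed from a typical setup:

    # Plain dictionaries flow into _bulk_insert() above.
    session.bulk_insert_mappings(
        User,
        [{"name": "alice"}, {"name": "bob"}],
        return_defaults=True,  # fetch server-generated values (e.g. PKs) back
        render_nulls=True,  # render None values as NULL instead of omitting them
    )

    # And into _bulk_update(); each dict must carry the primary key.
    session.bulk_update_mappings(User, [{"id": 1, "name": "alice2"}])
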
@@ -91,9 +109,8 @@ def _bulk_update(mapper, mappings, session_transaction,
def _changed_dict(mapper, state):
return dict(
(k, v)
- for k, v in state.dict.items() if k in state.committed_state or k
- in search_keys
-
+ for k, v in state.dict.items()
+ if k in state.committed_state or k in search_keys
)
if isstates:
@@ -107,7 +124,8 @@ def _bulk_update(mapper, mappings, session_transaction,
if session_transaction.session.connection_callable:
raise NotImplementedError(
"connection_callable / per-instance sharding "
- "not supported in bulk_update()")
+ "not supported in bulk_update()"
+ )
connection = session_transaction.connection(base_mapper)
@@ -115,21 +133,38 @@ def _bulk_update(mapper, mappings, session_transaction,
if not mapper.isa(super_mapper):
continue
- records = _collect_update_commands(None, table, (
- (None, mapping, mapper, connection,
- (mapping[mapper._version_id_prop.key]
- if mapper._version_id_prop else None))
- for mapping in mappings
- ), bulk=True)
+ records = _collect_update_commands(
+ None,
+ table,
+ (
+ (
+ None,
+ mapping,
+ mapper,
+ connection,
+ (
+ mapping[mapper._version_id_prop.key]
+ if mapper._version_id_prop
+ else None
+ ),
+ )
+ for mapping in mappings
+ ),
+ bulk=True,
+ )
- _emit_update_statements(base_mapper, None,
- cached_connections,
- super_mapper, table, records,
- bookkeeping=False)
+ _emit_update_statements(
+ base_mapper,
+ None,
+ cached_connections,
+ super_mapper,
+ table,
+ records,
+ bookkeeping=False,
+ )
-def save_obj(
- base_mapper, states, uowtransaction, single=False):
+def save_obj(base_mapper, states, uowtransaction, single=False):
"""Issue ``INSERT`` and/or ``UPDATE`` statements for a list
of objects.
@@ -150,19 +185,21 @@ def save_obj(
states_to_insert = []
cached_connections = _cached_connection_dict(base_mapper)
- for (state, dict_, mapper, connection,
- has_identity,
- row_switch, update_version_id) in _organize_states_for_save(
- base_mapper, states, uowtransaction
- ):
+ for (
+ state,
+ dict_,
+ mapper,
+ connection,
+ has_identity,
+ row_switch,
+ update_version_id,
+ ) in _organize_states_for_save(base_mapper, states, uowtransaction):
if has_identity or row_switch:
states_to_update.append(
(state, dict_, mapper, connection, update_version_id)
)
else:
- states_to_insert.append(
- (state, dict_, mapper, connection)
- )
+ states_to_insert.append((state, dict_, mapper, connection))
for table, mapper in base_mapper._sorted_tables.items():
if table not in mapper._pks_by_table:
@@ -170,18 +207,30 @@ def save_obj(
insert = _collect_insert_commands(table, states_to_insert)
update = _collect_update_commands(
- uowtransaction, table, states_to_update)
+ uowtransaction, table, states_to_update
+ )
- _emit_update_statements(base_mapper, uowtransaction,
- cached_connections,
- mapper, table, update)
+ _emit_update_statements(
+ base_mapper,
+ uowtransaction,
+ cached_connections,
+ mapper,
+ table,
+ update,
+ )
- _emit_insert_statements(base_mapper, uowtransaction,
- cached_connections,
- mapper, table, insert)
+ _emit_insert_statements(
+ base_mapper,
+ uowtransaction,
+ cached_connections,
+ mapper,
+ table,
+ insert,
+ )
_finalize_insert_update_commands(
- base_mapper, uowtransaction,
+ base_mapper,
+ uowtransaction,
chain(
(
(state, state_dict, mapper, connection, False)
@@ -189,10 +238,9 @@ def save_obj(
),
(
(state, state_dict, mapper, connection, True)
- for state, state_dict, mapper, connection,
- update_version_id in states_to_update
- )
- )
+ for state, state_dict, mapper, connection, update_version_id in states_to_update
+ ),
+ ),
)
@@ -203,9 +251,9 @@ def post_update(base_mapper, states, uowtransaction, post_update_cols):
"""
cached_connections = _cached_connection_dict(base_mapper)
- states_to_update = list(_organize_states_for_post_update(
- base_mapper,
- states, uowtransaction))
+ states_to_update = list(
+ _organize_states_for_post_update(base_mapper, states, uowtransaction)
+ )
for table, mapper in base_mapper._sorted_tables.items():
if table not in mapper._pks_by_table:
@@ -213,25 +261,32 @@ def post_update(base_mapper, states, uowtransaction, post_update_cols):
update = (
(
- state, state_dict, sub_mapper, connection,
+ state,
+ state_dict,
+ sub_mapper,
+ connection,
mapper._get_committed_state_attr_by_column(
state, state_dict, mapper.version_id_col
- ) if mapper.version_id_col is not None else None
+ )
+ if mapper.version_id_col is not None
+ else None,
)
- for
- state, state_dict, sub_mapper, connection in states_to_update
+ for state, state_dict, sub_mapper, connection in states_to_update
if table in sub_mapper._pks_by_table
)
update = _collect_post_update_commands(
- base_mapper, uowtransaction,
- table, update,
- post_update_cols
+ base_mapper, uowtransaction, table, update, post_update_cols
)
- _emit_post_update_statements(base_mapper, uowtransaction,
- cached_connections,
- mapper, table, update)
+ _emit_post_update_statements(
+ base_mapper,
+ uowtransaction,
+ cached_connections,
+ mapper,
+ table,
+ update,
+ )
def delete_obj(base_mapper, states, uowtransaction):
@@ -244,10 +299,9 @@ def delete_obj(base_mapper, states, uowtransaction):
cached_connections = _cached_connection_dict(base_mapper)
- states_to_delete = list(_organize_states_for_delete(
- base_mapper,
- states,
- uowtransaction))
+ states_to_delete = list(
+ _organize_states_for_delete(base_mapper, states, uowtransaction)
+ )
table_to_mapper = base_mapper._sorted_tables
@@ -258,14 +312,26 @@ def delete_obj(base_mapper, states, uowtransaction):
elif mapper.inherits and mapper.passive_deletes:
continue
- delete = _collect_delete_commands(base_mapper, uowtransaction,
- table, states_to_delete)
+ delete = _collect_delete_commands(
+ base_mapper, uowtransaction, table, states_to_delete
+ )
- _emit_delete_statements(base_mapper, uowtransaction,
- cached_connections, mapper, table, delete)
+ _emit_delete_statements(
+ base_mapper,
+ uowtransaction,
+ cached_connections,
+ mapper,
+ table,
+ delete,
+ )
- for state, state_dict, mapper, connection, \
- update_version_id in states_to_delete:
+ for (
+ state,
+ state_dict,
+ mapper,
+ connection,
+ update_version_id,
+ ) in states_to_delete:
mapper.dispatch.after_delete(mapper, connection, state)
@@ -282,8 +348,8 @@ def _organize_states_for_save(base_mapper, states, uowtransaction):
"""
for state, dict_, mapper, connection in _connections_for_states(
- base_mapper, uowtransaction,
- states):
+ base_mapper, uowtransaction, states
+ ):
has_identity = bool(state.key)
@@ -304,25 +370,29 @@ def _organize_states_for_save(base_mapper, states, uowtransaction):
# no instance_key attached to it), and another instance
# with the same identity key already exists as persistent.
# convert to an UPDATE if so.
- if not has_identity and \
- instance_key in uowtransaction.session.identity_map:
- instance = \
- uowtransaction.session.identity_map[instance_key]
+ if (
+ not has_identity
+ and instance_key in uowtransaction.session.identity_map
+ ):
+ instance = uowtransaction.session.identity_map[instance_key]
existing = attributes.instance_state(instance)
if not uowtransaction.was_already_deleted(existing):
if not uowtransaction.is_deleted(existing):
raise orm_exc.FlushError(
"New instance %s with identity key %s conflicts "
- "with persistent instance %s" %
- (state_str(state), instance_key,
- state_str(existing)))
+ "with persistent instance %s"
+ % (state_str(state), instance_key, state_str(existing))
+ )
base_mapper._log_debug(
"detected row switch for identity %s. "
"will update %s, remove %s from "
- "transaction", instance_key,
- state_str(state), state_str(existing))
+ "transaction",
+ instance_key,
+ state_str(state),
+ state_str(existing),
+ )
# remove the "delete" flag from the existing element
uowtransaction.remove_state_actions(existing)
@@ -332,14 +402,21 @@ def _organize_states_for_save(base_mapper, states, uowtransaction):
update_version_id = mapper._get_committed_state_attr_by_column(
row_switch if row_switch else state,
row_switch.dict if row_switch else dict_,
- mapper.version_id_col)
+ mapper.version_id_col,
+ )
- yield (state, dict_, mapper, connection,
- has_identity, row_switch, update_version_id)
+ yield (
+ state,
+ dict_,
+ mapper,
+ connection,
+ has_identity,
+ row_switch,
+ update_version_id,
+ )
-def _organize_states_for_post_update(base_mapper, states,
- uowtransaction):
+def _organize_states_for_post_update(base_mapper, states, uowtransaction):
"""Make an initial pass across a set of states for UPDATE
corresponding to post_update.
@@ -360,26 +437,28 @@ def _organize_states_for_delete(base_mapper, states, uowtransaction):
"""
for state, dict_, mapper, connection in _connections_for_states(
- base_mapper, uowtransaction,
- states):
+ base_mapper, uowtransaction, states
+ ):
mapper.dispatch.before_delete(mapper, connection, state)
if mapper.version_id_col is not None:
- update_version_id = \
- mapper._get_committed_state_attr_by_column(
- state, dict_,
- mapper.version_id_col)
+ update_version_id = mapper._get_committed_state_attr_by_column(
+ state, dict_, mapper.version_id_col
+ )
else:
update_version_id = None
- yield (
- state, dict_, mapper, connection, update_version_id)
+ yield (state, dict_, mapper, connection, update_version_id)
def _collect_insert_commands(
- table, states_to_insert,
- bulk=False, return_defaults=False, render_nulls=False):
+ table,
+ states_to_insert,
+ bulk=False,
+ return_defaults=False,
+ render_nulls=False,
+):
"""Identify sets of values to use in INSERT statements for a
list of states.
@@ -400,10 +479,16 @@ def _collect_insert_commands(
col = propkey_to_col[propkey]
if value is None and col not in eval_none and not render_nulls:
continue
- elif not bulk and hasattr(value, '__clause_element__') or \
- isinstance(value, sql.ClauseElement):
- value_params[col.key] = value.__clause_element__() \
- if hasattr(value, '__clause_element__') else value
+ elif (
+ not bulk
+ and hasattr(value, "__clause_element__")
+ or isinstance(value, sql.ClauseElement)
+ ):
+ value_params[col.key] = (
+ value.__clause_element__()
+ if hasattr(value, "__clause_element__")
+ else value
+ )
else:
params[col.key] = value
@@ -414,8 +499,11 @@ def _collect_insert_commands(
# which might be worth removing, as it should not be necessary
# and also produces confusion, given that "missing" and None
# now have distinct meanings
- for colkey in mapper._insert_cols_as_none[table].\
- difference(params).difference(value_params):
+ for colkey in (
+ mapper._insert_cols_as_none[table]
+ .difference(params)
+ .difference(value_params)
+ ):
params[colkey] = None
if not bulk or return_defaults:
@@ -424,28 +512,38 @@ def _collect_insert_commands(
has_all_pks = mapper._pk_keys_by_table[table].issubset(params)
if mapper.base_mapper.eager_defaults:
- has_all_defaults = mapper._server_default_cols[table].\
- issubset(params)
+ has_all_defaults = mapper._server_default_cols[table].issubset(
+ params
+ )
else:
has_all_defaults = True
else:
has_all_defaults = has_all_pks = True
- if mapper.version_id_generator is not False \
- and mapper.version_id_col is not None and \
- mapper.version_id_col in mapper._cols_by_table[table]:
- params[mapper.version_id_col.key] = \
- mapper.version_id_generator(None)
+ if (
+ mapper.version_id_generator is not False
+ and mapper.version_id_col is not None
+ and mapper.version_id_col in mapper._cols_by_table[table]
+ ):
+ params[mapper.version_id_col.key] = mapper.version_id_generator(
+ None
+ )
yield (
- state, state_dict, params, mapper,
- connection, value_params, has_all_pks,
- has_all_defaults)
+ state,
+ state_dict,
+ params,
+ mapper,
+ connection,
+ value_params,
+ has_all_pks,
+ has_all_defaults,
+ )
def _collect_update_commands(
- uowtransaction, table, states_to_update,
- bulk=False):
+ uowtransaction, table, states_to_update, bulk=False
+):
"""Identify sets of values to use in UPDATE statements for a
list of states.
@@ -457,8 +555,13 @@ def _collect_update_commands(
"""
- for state, state_dict, mapper, connection, \
- update_version_id in states_to_update:
+ for (
+ state,
+ state_dict,
+ mapper,
+ connection,
+ update_version_id,
+ ) in states_to_update:
if table not in mapper._pks_by_table:
continue
@@ -474,36 +577,48 @@ def _collect_update_commands(
# look at mapper attribute keys for pk
params = dict(
(propkey_to_col[propkey].key, state_dict[propkey])
- for propkey in
- set(propkey_to_col).intersection(state_dict).difference(
- mapper._pk_attr_keys_by_table[table])
+ for propkey in set(propkey_to_col)
+ .intersection(state_dict)
+ .difference(mapper._pk_attr_keys_by_table[table])
)
has_all_defaults = True
else:
params = {}
for propkey in set(propkey_to_col).intersection(
- state.committed_state):
+ state.committed_state
+ ):
value = state_dict[propkey]
col = propkey_to_col[propkey]
- if hasattr(value, '__clause_element__') or \
- isinstance(value, sql.ClauseElement):
- value_params[col] = value.__clause_element__() \
- if hasattr(value, '__clause_element__') else value
+ if hasattr(value, "__clause_element__") or isinstance(
+ value, sql.ClauseElement
+ ):
+ value_params[col] = (
+ value.__clause_element__()
+ if hasattr(value, "__clause_element__")
+ else value
+ )
# guard against values that generate non-__nonzero__
# objects for __eq__()
- elif state.manager[propkey].impl.is_equal(
- value, state.committed_state[propkey]) is not True:
+ elif (
+ state.manager[propkey].impl.is_equal(
+ value, state.committed_state[propkey]
+ )
+ is not True
+ ):
params[col.key] = value
if mapper.base_mapper.eager_defaults:
- has_all_defaults = mapper._server_onupdate_default_cols[table].\
- issubset(params)
+ has_all_defaults = mapper._server_onupdate_default_cols[
+ table
+ ].issubset(params)
else:
has_all_defaults = True
- if update_version_id is not None and \
- mapper.version_id_col in mapper._cols_by_table[table]:
+ if (
+ update_version_id is not None
+ and mapper.version_id_col in mapper._cols_by_table[table]
+ ):
if not bulk and not (params or value_params):
# HACK: check for history in other tables, in case the
@@ -511,10 +626,9 @@ def _collect_update_commands(
# where the version_id_col is. This logic was lost
# from 0.9 -> 1.0.0 and restored in 1.0.6.
for prop in mapper._columntoproperty.values():
- history = (
- state.manager[prop.key].impl.get_history(
- state, state_dict,
- attributes.PASSIVE_NO_INITIALIZE))
+ history = state.manager[prop.key].impl.get_history(
+ state, state_dict, attributes.PASSIVE_NO_INITIALIZE
+ )
if history.added:
break
else:
@@ -525,8 +639,9 @@ def _collect_update_commands(
no_params = not params and not value_params
params[col._label] = update_version_id
- if (bulk or col.key not in params) and \
- mapper.version_id_generator is not False:
+ if (
+ bulk or col.key not in params
+ ) and mapper.version_id_generator is not False:
val = mapper.version_id_generator(update_version_id)
params[col.key] = val
elif mapper.version_id_generator is False and no_params:
@@ -545,9 +660,9 @@ def _collect_update_commands(
# look at mapper attribute keys for pk
pk_params = dict(
(propkey_to_col[propkey]._label, state_dict.get(propkey))
- for propkey in
- set(propkey_to_col).
- intersection(mapper._pk_attr_keys_by_table[table])
+ for propkey in set(propkey_to_col).intersection(
+ mapper._pk_attr_keys_by_table[table]
+ )
)
else:
pk_params = {}
@@ -555,12 +670,15 @@ def _collect_update_commands(
propkey = mapper._columntoproperty[col].key
history = state.manager[propkey].impl.get_history(
- state, state_dict, attributes.PASSIVE_OFF)
+ state, state_dict, attributes.PASSIVE_OFF
+ )
if history.added:
- if not history.deleted or \
- ("pk_cascaded", state, col) in \
- uowtransaction.attributes:
+ if (
+ not history.deleted
+ or ("pk_cascaded", state, col)
+ in uowtransaction.attributes
+ ):
pk_params[col._label] = history.added[0]
params.pop(col.key, None)
else:
@@ -573,24 +691,38 @@ def _collect_update_commands(
if pk_params[col._label] is None:
raise orm_exc.FlushError(
"Can't update table %s using NULL for primary "
- "key value on column %s" % (table, col))
+ "key value on column %s" % (table, col)
+ )
if params or value_params:
params.update(pk_params)
yield (
- state, state_dict, params, mapper,
- connection, value_params, has_all_defaults, has_all_pks)
+ state,
+ state_dict,
+ params,
+ mapper,
+ connection,
+ value_params,
+ has_all_defaults,
+ has_all_pks,
+ )
-def _collect_post_update_commands(base_mapper, uowtransaction, table,
- states_to_update, post_update_cols):
+def _collect_post_update_commands(
+ base_mapper, uowtransaction, table, states_to_update, post_update_cols
+):
"""Identify sets of values to use in UPDATE statements for a
list of states within a post_update operation.
"""
- for state, state_dict, mapper, connection, \
- update_version_id in states_to_update:
+ for (
+ state,
+ state_dict,
+ mapper,
+ connection,
+ update_version_id,
+ ) in states_to_update:
# assert table in mapper._pks_by_table
@@ -600,100 +732,128 @@ def _collect_post_update_commands(base_mapper, uowtransaction, table,
for col in mapper._cols_by_table[table]:
if col in pks:
- params[col._label] = \
- mapper._get_state_attr_by_column(
- state,
- state_dict, col, passive=attributes.PASSIVE_OFF)
+ params[col._label] = mapper._get_state_attr_by_column(
+ state, state_dict, col, passive=attributes.PASSIVE_OFF
+ )
elif col in post_update_cols or col.onupdate is not None:
prop = mapper._columntoproperty[col]
history = state.manager[prop.key].impl.get_history(
- state, state_dict,
- attributes.PASSIVE_NO_INITIALIZE)
+ state, state_dict, attributes.PASSIVE_NO_INITIALIZE
+ )
if history.added:
value = history.added[0]
params[col.key] = value
hasdata = True
if hasdata:
- if update_version_id is not None and \
- mapper.version_id_col in mapper._cols_by_table[table]:
+ if (
+ update_version_id is not None
+ and mapper.version_id_col in mapper._cols_by_table[table]
+ ):
col = mapper.version_id_col
params[col._label] = update_version_id
- if bool(state.key) and col.key not in params and \
- mapper.version_id_generator is not False:
+ if (
+ bool(state.key)
+ and col.key not in params
+ and mapper.version_id_generator is not False
+ ):
val = mapper.version_id_generator(update_version_id)
params[col.key] = val
yield state, state_dict, mapper, connection, params
-def _collect_delete_commands(base_mapper, uowtransaction, table,
- states_to_delete):
+def _collect_delete_commands(
+ base_mapper, uowtransaction, table, states_to_delete
+):
"""Identify values to use in DELETE statements for a list of
states to be deleted."""
- for state, state_dict, mapper, connection, \
- update_version_id in states_to_delete:
+ for (
+ state,
+ state_dict,
+ mapper,
+ connection,
+ update_version_id,
+ ) in states_to_delete:
if table not in mapper._pks_by_table:
continue
params = {}
for col in mapper._pks_by_table[table]:
- params[col.key] = \
- value = \
- mapper._get_committed_state_attr_by_column(
- state, state_dict, col)
+ params[
+ col.key
+ ] = value = mapper._get_committed_state_attr_by_column(
+ state, state_dict, col
+ )
if value is None:
raise orm_exc.FlushError(
"Can't delete from table %s "
"using NULL for primary "
- "key value on column %s" % (table, col))
+ "key value on column %s" % (table, col)
+ )
- if update_version_id is not None and \
- mapper.version_id_col in mapper._cols_by_table[table]:
+ if (
+ update_version_id is not None
+ and mapper.version_id_col in mapper._cols_by_table[table]
+ ):
params[mapper.version_id_col.key] = update_version_id
yield params, connection
-def _emit_update_statements(base_mapper, uowtransaction,
- cached_connections, mapper, table, update,
- bookkeeping=True):
+def _emit_update_statements(
+ base_mapper,
+ uowtransaction,
+ cached_connections,
+ mapper,
+ table,
+ update,
+ bookkeeping=True,
+):
"""Emit UPDATE statements corresponding to value lists collected
by _collect_update_commands()."""
- needs_version_id = mapper.version_id_col is not None and \
- mapper.version_id_col in mapper._cols_by_table[table]
+ needs_version_id = (
+ mapper.version_id_col is not None
+ and mapper.version_id_col in mapper._cols_by_table[table]
+ )
def update_stmt():
clause = sql.and_()
for col in mapper._pks_by_table[table]:
- clause.clauses.append(col == sql.bindparam(col._label,
- type_=col.type))
+ clause.clauses.append(
+ col == sql.bindparam(col._label, type_=col.type)
+ )
if needs_version_id:
clause.clauses.append(
- mapper.version_id_col == sql.bindparam(
+ mapper.version_id_col
+ == sql.bindparam(
mapper.version_id_col._label,
- type_=mapper.version_id_col.type))
+ type_=mapper.version_id_col.type,
+ )
+ )
stmt = table.update(clause)
return stmt
- cached_stmt = base_mapper._memo(('update', table), update_stmt)
-
- for (connection, paramkeys, hasvalue, has_all_defaults, has_all_pks), \
- records in groupby(
- update,
- lambda rec: (
- rec[4], # connection
- set(rec[2]), # set of parameter keys
- bool(rec[5]), # whether or not we have "value" parameters
- rec[6], # has_all_defaults
- rec[7] # has all pks
- )
+ cached_stmt = base_mapper._memo(("update", table), update_stmt)
+
+ for (
+ (connection, paramkeys, hasvalue, has_all_defaults, has_all_pks),
+ records,
+ ) in groupby(
+ update,
+ lambda rec: (
+ rec[4], # connection
+ set(rec[2]), # set of parameter keys
+ bool(rec[5]), # whether or not we have "value" parameters
+ rec[6], # has_all_defaults
+ rec[7], # has all pks
+ ),
):
rows = 0
records = list(records)
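
The grouping key rewritten above is what makes executemany() batching possible: consecutive records sharing a connection and parameter-key set are emitted through one statement. A standalone illustration of the itertools.groupby pattern, with plain tuples standing in for the ORM records:

    from itertools import groupby

    records = [
        ("conn1", {"id": 1, "x": 10}),
        ("conn1", {"id": 2, "x": 20}),
        ("conn2", {"id": 3, "x": 30}),
    ]

    # groupby() only merges *consecutive* records, which is exactly the
    # behavior wanted here: row access order is preserved.
    for (conn, keys), group in groupby(
        records, lambda rec: (rec[0], frozenset(rec[1]))
    ):
        multiparams = [params for _, params in group]
        print(conn, sorted(keys), multiparams)
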
@@ -704,8 +864,11 @@ def _emit_update_statements(base_mapper, uowtransaction,
if not has_all_pks:
statement = statement.return_defaults()
return_defaults = True
- elif bookkeeping and not has_all_defaults and \
- mapper.base_mapper.eager_defaults:
+ elif (
+ bookkeeping
+ and not has_all_defaults
+ and mapper.base_mapper.eager_defaults
+ ):
statement = statement.return_defaults()
return_defaults = True
elif mapper.version_id_col is not None:
@@ -718,17 +881,24 @@ def _emit_update_statements(base_mapper, uowtransaction,
else connection.dialect.supports_sane_rowcount_returning
)
- assert_multirow = assert_singlerow and \
- connection.dialect.supports_sane_multi_rowcount
+ assert_multirow = (
+ assert_singlerow
+ and connection.dialect.supports_sane_multi_rowcount
+ )
allow_multirow = has_all_defaults and not needs_version_id
if hasvalue:
- for state, state_dict, params, mapper, \
- connection, value_params, \
- has_all_defaults, has_all_pks in records:
- c = connection.execute(
- statement.values(value_params),
- params)
+ for (
+ state,
+ state_dict,
+ params,
+ mapper,
+ connection,
+ value_params,
+ has_all_defaults,
+ has_all_pks,
+ ) in records:
+ c = connection.execute(statement.values(value_params), params)
if bookkeeping:
_postfetch(
mapper,
@@ -738,17 +908,26 @@ def _emit_update_statements(base_mapper, uowtransaction,
state_dict,
c,
c.context.compiled_parameters[0],
- value_params)
+ value_params,
+ )
rows += c.rowcount
check_rowcount = assert_singlerow
else:
if not allow_multirow:
check_rowcount = assert_singlerow
- for state, state_dict, params, mapper, \
- connection, value_params, has_all_defaults, \
- has_all_pks in records:
- c = cached_connections[connection].\
- execute(statement, params)
+ for (
+ state,
+ state_dict,
+ params,
+ mapper,
+ connection,
+ value_params,
+ has_all_defaults,
+ has_all_pks,
+ ) in records:
+ c = cached_connections[connection].execute(
+ statement, params
+ )
# TODO: why with bookkeeping=False?
if bookkeeping:
@@ -760,24 +939,32 @@ def _emit_update_statements(base_mapper, uowtransaction,
state_dict,
c,
c.context.compiled_parameters[0],
- value_params)
+ value_params,
+ )
rows += c.rowcount
else:
multiparams = [rec[2] for rec in records]
check_rowcount = assert_multirow or (
- assert_singlerow and
- len(multiparams) == 1
+ assert_singlerow and len(multiparams) == 1
)
- c = cached_connections[connection].\
- execute(statement, multiparams)
+ c = cached_connections[connection].execute(
+ statement, multiparams
+ )
rows += c.rowcount
- for state, state_dict, params, mapper, \
- connection, value_params, \
- has_all_defaults, has_all_pks in records:
+ for (
+ state,
+ state_dict,
+ params,
+ mapper,
+ connection,
+ value_params,
+ has_all_defaults,
+ has_all_pks,
+ ) in records:
if bookkeeping:
_postfetch(
mapper,
@@ -787,59 +974,85 @@ def _emit_update_statements(base_mapper, uowtransaction,
state_dict,
c,
c.context.compiled_parameters[0],
- value_params)
+ value_params,
+ )
if check_rowcount:
if rows != len(records):
raise orm_exc.StaleDataError(
"UPDATE statement on table '%s' expected to "
- "update %d row(s); %d were matched." %
- (table.description, len(records), rows))
+ "update %d row(s); %d were matched."
+ % (table.description, len(records), rows)
+ )
elif needs_version_id:
- util.warn("Dialect %s does not support updated rowcount "
- "- versioning cannot be verified." %
- c.dialect.dialect_description)
+ util.warn(
+ "Dialect %s does not support updated rowcount "
+ "- versioning cannot be verified."
+ % c.dialect.dialect_description
+ )
-def _emit_insert_statements(base_mapper, uowtransaction,
- cached_connections, mapper, table, insert,
- bookkeeping=True):
+def _emit_insert_statements(
+ base_mapper,
+ uowtransaction,
+ cached_connections,
+ mapper,
+ table,
+ insert,
+ bookkeeping=True,
+):
"""Emit INSERT statements corresponding to value lists collected
by _collect_insert_commands()."""
- cached_stmt = base_mapper._memo(('insert', table), table.insert)
-
- for (connection, pkeys, hasvalue, has_all_pks, has_all_defaults), \
- records in groupby(
- insert,
- lambda rec: (
- rec[4], # connection
- set(rec[2]), # parameter keys
- bool(rec[5]), # whether we have "value" parameters
- rec[6],
- rec[7])):
+ cached_stmt = base_mapper._memo(("insert", table), table.insert)
+
+ for (
+ (connection, pkeys, hasvalue, has_all_pks, has_all_defaults),
+ records,
+ ) in groupby(
+ insert,
+ lambda rec: (
+ rec[4], # connection
+ set(rec[2]), # parameter keys
+ bool(rec[5]), # whether we have "value" parameters
+ rec[6],
+ rec[7],
+ ),
+ ):
statement = cached_stmt
- if not bookkeeping or \
- (
- has_all_defaults
- or not base_mapper.eager_defaults
- or not connection.dialect.implicit_returning
- ) and has_all_pks and not hasvalue:
+ if (
+ not bookkeeping
+ or (
+ has_all_defaults
+ or not base_mapper.eager_defaults
+ or not connection.dialect.implicit_returning
+ )
+ and has_all_pks
+ and not hasvalue
+ ):
records = list(records)
multiparams = [rec[2] for rec in records]
- c = cached_connections[connection].\
- execute(statement, multiparams)
+ c = cached_connections[connection].execute(statement, multiparams)
if bookkeeping:
- for (state, state_dict, params, mapper_rec,
- conn, value_params, has_all_pks, has_all_defaults), \
- last_inserted_params in \
- zip(records, c.context.compiled_parameters):
+ for (
+ (
+ state,
+ state_dict,
+ params,
+ mapper_rec,
+ conn,
+ value_params,
+ has_all_pks,
+ has_all_defaults,
+ ),
+ last_inserted_params,
+ ) in zip(records, c.context.compiled_parameters):
if state:
_postfetch(
mapper_rec,
@@ -849,7 +1062,8 @@ def _emit_insert_statements(base_mapper, uowtransaction,
state_dict,
c,
last_inserted_params,
- value_params)
+ value_params,
+ )
else:
_postfetch_bulk_save(mapper_rec, state_dict, table)
@@ -859,24 +1073,33 @@ def _emit_insert_statements(base_mapper, uowtransaction,
elif mapper.version_id_col is not None:
statement = statement.return_defaults(mapper.version_id_col)
- for state, state_dict, params, mapper_rec, \
- connection, value_params, \
- has_all_pks, has_all_defaults in records:
+ for (
+ state,
+ state_dict,
+ params,
+ mapper_rec,
+ connection,
+ value_params,
+ has_all_pks,
+ has_all_defaults,
+ ) in records:
if value_params:
result = connection.execute(
- statement.values(value_params),
- params)
+ statement.values(value_params), params
+ )
else:
- result = cached_connections[connection].\
- execute(statement, params)
+ result = cached_connections[connection].execute(
+ statement, params
+ )
primary_key = result.context.inserted_primary_key
if primary_key is not None:
# set primary key attributes
- for pk, col in zip(primary_key,
- mapper._pks_by_table[table]):
+ for pk, col in zip(
+ primary_key, mapper._pks_by_table[table]
+ ):
prop = mapper_rec._columntoproperty[col]
if state_dict.get(prop.key) is None:
state_dict[prop.key] = pk
@@ -890,31 +1113,39 @@ def _emit_insert_statements(base_mapper, uowtransaction,
state_dict,
result,
result.context.compiled_parameters[0],
- value_params)
+ value_params,
+ )
else:
_postfetch_bulk_save(mapper_rec, state_dict, table)
-def _emit_post_update_statements(base_mapper, uowtransaction,
- cached_connections, mapper, table, update):
+def _emit_post_update_statements(
+ base_mapper, uowtransaction, cached_connections, mapper, table, update
+):
"""Emit UPDATE statements corresponding to value lists collected
by _collect_post_update_commands()."""
- needs_version_id = mapper.version_id_col is not None and \
- mapper.version_id_col in mapper._cols_by_table[table]
+ needs_version_id = (
+ mapper.version_id_col is not None
+ and mapper.version_id_col in mapper._cols_by_table[table]
+ )
def update_stmt():
clause = sql.and_()
for col in mapper._pks_by_table[table]:
- clause.clauses.append(col == sql.bindparam(col._label,
- type_=col.type))
+ clause.clauses.append(
+ col == sql.bindparam(col._label, type_=col.type)
+ )
if needs_version_id:
clause.clauses.append(
- mapper.version_id_col == sql.bindparam(
+ mapper.version_id_col
+ == sql.bindparam(
mapper.version_id_col._label,
- type_=mapper.version_id_col.type))
+ type_=mapper.version_id_col.type,
+ )
+ )
stmt = table.update(clause)
@@ -923,17 +1154,15 @@ def _emit_post_update_statements(base_mapper, uowtransaction,
return stmt
- statement = base_mapper._memo(('post_update', table), update_stmt)
+ statement = base_mapper._memo(("post_update", table), update_stmt)
# execute each UPDATE in the order according to the original
# list of states to guarantee row access order, but
# also group them into common (connection, cols) sets
# to support executemany().
for key, records in groupby(
- update, lambda rec: (
- rec[3], # connection
- set(rec[4]), # parameter keys
- )
+        update,
+        lambda rec: (
+            rec[3],  # connection
+            set(rec[4]),  # parameter keys
+        ),
):
rows = 0
@@ -945,84 +1174,96 @@ def _emit_post_update_statements(base_mapper, uowtransaction,
if mapper.version_id_col is None
else connection.dialect.supports_sane_rowcount_returning
)
- assert_multirow = assert_singlerow and \
- connection.dialect.supports_sane_multi_rowcount
+ assert_multirow = (
+ assert_singlerow
+ and connection.dialect.supports_sane_multi_rowcount
+ )
allow_multirow = not needs_version_id or assert_multirow
-
if not allow_multirow:
check_rowcount = assert_singlerow
- for state, state_dict, mapper_rec, \
- connection, params in records:
- c = cached_connections[connection].\
- execute(statement, params)
+ for state, state_dict, mapper_rec, connection, params in records:
+ c = cached_connections[connection].execute(statement, params)
_postfetch_post_update(
- mapper_rec, uowtransaction, table, state, state_dict,
- c, c.context.compiled_parameters[0])
+ mapper_rec,
+ uowtransaction,
+ table,
+ state,
+ state_dict,
+ c,
+ c.context.compiled_parameters[0],
+ )
rows += c.rowcount
else:
multiparams = [
- params for
- state, state_dict, mapper_rec, conn, params in records]
+ params
+ for state, state_dict, mapper_rec, conn, params in records
+ ]
check_rowcount = assert_multirow or (
- assert_singlerow and
- len(multiparams) == 1
+ assert_singlerow and len(multiparams) == 1
)
- c = cached_connections[connection].\
- execute(statement, multiparams)
+ c = cached_connections[connection].execute(statement, multiparams)
rows += c.rowcount
- for state, state_dict, mapper_rec, \
- connection, params in records:
+ for state, state_dict, mapper_rec, connection, params in records:
_postfetch_post_update(
- mapper_rec, uowtransaction, table, state, state_dict,
- c, c.context.compiled_parameters[0])
+ mapper_rec,
+ uowtransaction,
+ table,
+ state,
+ state_dict,
+ c,
+ c.context.compiled_parameters[0],
+ )
if check_rowcount:
if rows != len(records):
raise orm_exc.StaleDataError(
"UPDATE statement on table '%s' expected to "
- "update %d row(s); %d were matched." %
- (table.description, len(records), rows))
+ "update %d row(s); %d were matched."
+ % (table.description, len(records), rows)
+ )
elif needs_version_id:
- util.warn("Dialect %s does not support updated rowcount "
- "- versioning cannot be verified." %
- c.dialect.dialect_description)
+ util.warn(
+ "Dialect %s does not support updated rowcount "
+ "- versioning cannot be verified."
+ % c.dialect.dialect_description
+ )
-def _emit_delete_statements(base_mapper, uowtransaction, cached_connections,
- mapper, table, delete):
+def _emit_delete_statements(
+ base_mapper, uowtransaction, cached_connections, mapper, table, delete
+):
"""Emit DELETE statements corresponding to value lists collected
by _collect_delete_commands()."""
- need_version_id = mapper.version_id_col is not None and \
- mapper.version_id_col in mapper._cols_by_table[table]
+ need_version_id = (
+ mapper.version_id_col is not None
+ and mapper.version_id_col in mapper._cols_by_table[table]
+ )
def delete_stmt():
clause = sql.and_()
for col in mapper._pks_by_table[table]:
clause.clauses.append(
- col == sql.bindparam(col.key, type_=col.type))
+ col == sql.bindparam(col.key, type_=col.type)
+ )
if need_version_id:
clause.clauses.append(
- mapper.version_id_col ==
- sql.bindparam(
- mapper.version_id_col.key,
- type_=mapper.version_id_col.type
+ mapper.version_id_col
+ == sql.bindparam(
+ mapper.version_id_col.key, type_=mapper.version_id_col.type
)
)
return table.delete(clause)
- statement = base_mapper._memo(('delete', table), delete_stmt)
- for connection, recs in groupby(
- delete,
- lambda rec: rec[1] # connection
- ):
+ statement = base_mapper._memo(("delete", table), delete_stmt)
+ for connection, recs in groupby(delete, lambda rec: rec[1]): # connection
del_objects = [params for params, connection in recs]
connection = cached_connections[connection]
@@ -1049,9 +1290,10 @@ def _emit_delete_statements(base_mapper, uowtransaction, cached_connections,
else:
util.warn(
"Dialect %s does not support deleted rowcount "
- "- versioning cannot be verified." %
- connection.dialect.dialect_description,
- stacklevel=12)
+ "- versioning cannot be verified."
+ % connection.dialect.dialect_description,
+ stacklevel=12,
+ )
connection.execute(statement, del_objects)
else:
c = connection.execute(statement, del_objects)
@@ -1061,23 +1303,26 @@ def _emit_delete_statements(base_mapper, uowtransaction, cached_connections,
rows_matched = c.rowcount
- if base_mapper.confirm_deleted_rows and \
- rows_matched > -1 and expected != rows_matched:
+ if (
+ base_mapper.confirm_deleted_rows
+ and rows_matched > -1
+ and expected != rows_matched
+ ):
if only_warn:
util.warn(
"DELETE statement on table '%s' expected to "
"delete %d row(s); %d were matched. Please set "
"confirm_deleted_rows=False within the mapper "
- "configuration to prevent this warning." %
- (table.description, expected, rows_matched)
+ "configuration to prevent this warning."
+ % (table.description, expected, rows_matched)
)
else:
raise orm_exc.StaleDataError(
"DELETE statement on table '%s' expected to "
"delete %d row(s); %d were matched. Please set "
"confirm_deleted_rows=False within the mapper "
- "configuration to prevent this warning." %
- (table.description, expected, rows_matched)
+ "configuration to prevent this warning."
+ % (table.description, expected, rows_matched)
)
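
The rowcount checks above are driven by mapper-level settings. A minimal declarative sketch of enabling them; the Widget model is hypothetical:

    from sqlalchemy import Column, Integer, String
    from sqlalchemy.ext.declarative import declarative_base

    Base = declarative_base()


    class Widget(Base):
        __tablename__ = "widget"

        id = Column(Integer, primary_key=True)
        name = Column(String(50))
        version_id = Column(Integer, nullable=False)

        __mapper_args__ = {
            # feeds the update_version_id bookkeeping in the collect/emit
            # functions above; the default generator increments by one
            "version_id_col": version_id,
            # False downgrades the DELETE rowcount mismatch from
            # StaleDataError to the warning shown above
            "confirm_deleted_rows": False,
        }
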
@@ -1091,13 +1336,16 @@ def _finalize_insert_update_commands(base_mapper, uowtransaction, states):
if mapper._readonly_props:
readonly = state.unmodified_intersection(
[
- p.key for p in mapper._readonly_props
+ p.key
+ for p in mapper._readonly_props
if (
- p.expire_on_flush and
- (not p.deferred or p.key in state.dict)
- ) or (
- not p.expire_on_flush and
- not p.deferred and p.key not in state.dict
+ p.expire_on_flush
+ and (not p.deferred or p.key in state.dict)
+ )
+ or (
+ not p.expire_on_flush
+ and not p.deferred
+ and p.key not in state.dict
)
]
)
@@ -1112,11 +1360,14 @@ def _finalize_insert_update_commands(base_mapper, uowtransaction, states):
if base_mapper.eager_defaults:
toload_now.extend(
state._unloaded_non_object.intersection(
- mapper._server_default_plus_onupdate_propkeys)
+ mapper._server_default_plus_onupdate_propkeys
+ )
)
- if mapper.version_id_col is not None and \
- mapper.version_id_generator is False:
+ if (
+ mapper.version_id_col is not None
+ and mapper.version_id_generator is False
+ ):
if mapper._version_id_prop.key in state.unloaded:
toload_now.extend([mapper._version_id_prop.key])
@@ -1124,8 +1375,10 @@ def _finalize_insert_update_commands(base_mapper, uowtransaction, states):
state.key = base_mapper._identity_key_from_state(state)
loading.load_on_ident(
uowtransaction.session.query(mapper),
- state.key, refresh_state=state,
- only_load_props=toload_now)
+ state.key,
+ refresh_state=state,
+ only_load_props=toload_now,
+ )
# call after_XXX extensions
if not has_identity:
@@ -1133,23 +1386,29 @@ def _finalize_insert_update_commands(base_mapper, uowtransaction, states):
else:
mapper.dispatch.after_update(mapper, connection, state)
- if mapper.version_id_generator is False and \
- mapper.version_id_col is not None:
+ if (
+ mapper.version_id_generator is False
+ and mapper.version_id_col is not None
+ ):
if state_dict[mapper._version_id_prop.key] is None:
raise orm_exc.FlushError(
- "Instance does not contain a non-NULL version value")
+ "Instance does not contain a non-NULL version value"
+ )
-def _postfetch_post_update(mapper, uowtransaction, table,
- state, dict_, result, params):
+def _postfetch_post_update(
+ mapper, uowtransaction, table, state, dict_, result, params
+):
if uowtransaction.is_deleted(state):
return
prefetch_cols = result.context.compiled.prefetch
postfetch_cols = result.context.compiled.postfetch
- if mapper.version_id_col is not None and \
- mapper.version_id_col in mapper._cols_by_table[table]:
+ if (
+ mapper.version_id_col is not None
+ and mapper.version_id_col in mapper._cols_by_table[table]
+ ):
prefetch_cols = list(prefetch_cols) + [mapper.version_id_col]
refresh_flush = bool(mapper.class_manager.dispatch.refresh_flush)
@@ -1164,18 +1423,23 @@ def _postfetch_post_update(mapper, uowtransaction, table,
if refresh_flush and load_evt_attrs:
mapper.class_manager.dispatch.refresh_flush(
- state, uowtransaction, load_evt_attrs)
+ state, uowtransaction, load_evt_attrs
+ )
if postfetch_cols:
- state._expire_attributes(state.dict,
- [mapper._columntoproperty[c].key
- for c in postfetch_cols if c in
- mapper._columntoproperty]
- )
+ state._expire_attributes(
+ state.dict,
+ [
+ mapper._columntoproperty[c].key
+ for c in postfetch_cols
+ if c in mapper._columntoproperty
+ ],
+ )
-def _postfetch(mapper, uowtransaction, table,
- state, dict_, result, params, value_params):
+def _postfetch(
+ mapper, uowtransaction, table, state, dict_, result, params, value_params
+):
"""Expire attributes in need of newly persisted database state,
after an INSERT or UPDATE statement has proceeded for that
state."""
@@ -1184,8 +1448,10 @@ def _postfetch(mapper, uowtransaction, table,
postfetch_cols = result.context.compiled.postfetch
returning_cols = result.context.compiled.returning
- if mapper.version_id_col is not None and \
- mapper.version_id_col in mapper._cols_by_table[table]:
+ if (
+ mapper.version_id_col is not None
+ and mapper.version_id_col in mapper._cols_by_table[table]
+ ):
prefetch_cols = list(prefetch_cols) + [mapper.version_id_col]
refresh_flush = bool(mapper.class_manager.dispatch.refresh_flush)
@@ -1219,23 +1485,32 @@ def _postfetch(mapper, uowtransaction, table,
if refresh_flush and load_evt_attrs:
mapper.class_manager.dispatch.refresh_flush(
- state, uowtransaction, load_evt_attrs)
+ state, uowtransaction, load_evt_attrs
+ )
if postfetch_cols:
- state._expire_attributes(state.dict,
- [mapper._columntoproperty[c].key
- for c in postfetch_cols if c in
- mapper._columntoproperty]
- )
+ state._expire_attributes(
+ state.dict,
+ [
+ mapper._columntoproperty[c].key
+ for c in postfetch_cols
+ if c in mapper._columntoproperty
+ ],
+ )
# synchronize newly inserted ids from one table to the next
# TODO: this still goes a little too often. would be nice to
# have definitive list of "columns that changed" here
for m, equated_pairs in mapper._table_to_equated[table]:
- sync.populate(state, m, state, m,
- equated_pairs,
- uowtransaction,
- mapper.passive_updates)
+ sync.populate(
+ state,
+ m,
+ state,
+ m,
+ equated_pairs,
+ uowtransaction,
+ mapper.passive_updates,
+ )
def _postfetch_bulk_save(mapper, dict_, table):
@@ -1255,8 +1530,7 @@ def _connections_for_states(base_mapper, uowtransaction, states):
# organize individual states with the connection
# to use for update
if uowtransaction.session.connection_callable:
- connection_callable = \
- uowtransaction.session.connection_callable
+ connection_callable = uowtransaction.session.connection_callable
else:
connection = uowtransaction.transaction.connection(base_mapper)
connection_callable = None
@@ -1275,7 +1549,8 @@ def _cached_connection_dict(base_mapper):
return util.PopulateDict(
lambda conn: conn.execution_options(
compiled_cache=base_mapper._compiled_cache
- ))
+ )
+ )
def _sort_states(states):
@@ -1287,9 +1562,12 @@ def _sort_states(states):
except TypeError as err:
raise sa_exc.InvalidRequestError(
"Could not sort objects by primary key; primary key "
- "values must be sortable in Python (was: %s)" % err)
- return sorted(pending, key=operator.attrgetter("insert_order")) + \
- persistent_sorted
+ "values must be sortable in Python (was: %s)" % err
+ )
+ return (
+ sorted(pending, key=operator.attrgetter("insert_order"))
+ + persistent_sorted
+ )
class BulkUD(object):
@@ -1302,21 +1580,22 @@ class BulkUD(object):
def _validate_query_state(self):
for attr, methname, notset, op in (
- ('_limit', 'limit()', None, operator.is_),
- ('_offset', 'offset()', None, operator.is_),
- ('_order_by', 'order_by()', False, operator.is_),
- ('_group_by', 'group_by()', False, operator.is_),
- ('_distinct', 'distinct()', False, operator.is_),
+ ("_limit", "limit()", None, operator.is_),
+ ("_offset", "offset()", None, operator.is_),
+ ("_order_by", "order_by()", False, operator.is_),
+ ("_group_by", "group_by()", False, operator.is_),
+ ("_distinct", "distinct()", False, operator.is_),
(
- '_from_obj',
- 'join(), outerjoin(), select_from(), or from_self()',
- (), operator.eq)
+ "_from_obj",
+ "join(), outerjoin(), select_from(), or from_self()",
+ (),
+ operator.eq,
+ ),
):
if not op(getattr(self.query, attr), notset):
raise sa_exc.InvalidRequestError(
"Can't call Query.update() or Query.delete() "
- "when %s has been called" %
- (methname, )
+ "when %s has been called" % (methname,)
)
@property
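
_validate_query_state() is what rejects bulk UPDATE/DELETE on queries that already carry incompatible modifiers. A short sketch of the behavior, assuming the session and User model from the earlier sketches:

    from sqlalchemy import exc as sa_exc

    try:
        # limit() is one of the modifiers checked above.
        session.query(User).limit(5).update({"name": "x"})
    except sa_exc.InvalidRequestError as err:
        print(err)  # Can't call Query.update() or Query.delete() when ...
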
@@ -1330,8 +1609,8 @@ class BulkUD(object):
except KeyError:
raise sa_exc.ArgumentError(
"Valid strategies for session synchronization "
- "are %s" % (", ".join(sorted(repr(x)
- for x in lookup))))
+ "are %s" % (", ".join(sorted(repr(x) for x in lookup)))
+ )
else:
return klass(*arg)
@@ -1400,9 +1679,9 @@ class BulkEvaluate(BulkUD):
try:
evaluator_compiler = evaluator.EvaluatorCompiler(target_cls)
if query.whereclause is not None:
- eval_condition = evaluator_compiler.process(
- query.whereclause)
+ eval_condition = evaluator_compiler.process(query.whereclause)
else:
+
def eval_condition(obj):
return True
@@ -1411,15 +1690,20 @@ class BulkEvaluate(BulkUD):
except evaluator.UnevaluatableError as err:
raise sa_exc.InvalidRequestError(
'Could not evaluate current criteria in Python: "%s". '
- 'Specify \'fetch\' or False for the '
- 'synchronize_session parameter.' % err)
+ "Specify 'fetch' or False for the "
+ "synchronize_session parameter." % err
+ )
# TODO: detect when the where clause is a trivial primary key match
self.matched_objects = [
- obj for (cls, pk, identity_token), obj in
- query.session.identity_map.items()
- if issubclass(cls, target_cls) and
- eval_condition(obj)]
+ obj
+ for (
+ cls,
+ pk,
+ identity_token,
+ ), obj in query.session.identity_map.items()
+ if issubclass(cls, target_cls) and eval_condition(obj)
+ ]
class BulkFetch(BulkUD):
@@ -1430,11 +1714,11 @@ class BulkFetch(BulkUD):
session = query.session
context = query._compile_context()
select_stmt = context.statement.with_only_columns(
- self.primary_table.primary_key)
+ self.primary_table.primary_key
+ )
self.matched_rows = session.execute(
- select_stmt,
- mapper=self.mapper,
- params=query._params).fetchall()
+ select_stmt, mapper=self.mapper, params=query._params
+ ).fetchall()
class BulkUpdate(BulkUD):
@@ -1447,18 +1731,26 @@ class BulkUpdate(BulkUD):
@classmethod
def factory(cls, query, synchronize_session, values, update_kwargs):
- return BulkUD._factory({
- "evaluate": BulkUpdateEvaluate,
- "fetch": BulkUpdateFetch,
- False: BulkUpdate
- }, synchronize_session, query, values, update_kwargs)
+ return BulkUD._factory(
+ {
+ "evaluate": BulkUpdateEvaluate,
+ "fetch": BulkUpdateFetch,
+ False: BulkUpdate,
+ },
+ synchronize_session,
+ query,
+ values,
+ update_kwargs,
+ )
@property
def _resolved_values(self):
values = []
for k, v in (
- self.values.items() if hasattr(self.values, 'items')
- else self.values):
+ self.values.items()
+ if hasattr(self.values, "items")
+ else self.values
+ ):
if self.mapper:
if isinstance(k, util.string_types):
desc = _entity_descriptor(self.mapper, k)
@@ -1478,7 +1770,7 @@ class BulkUpdate(BulkUD):
if isinstance(k, attributes.QueryableAttribute):
values.append((k.key, v))
continue
- elif hasattr(k, '__clause_element__'):
+ elif hasattr(k, "__clause_element__"):
k = k.__clause_element__()
if self.mapper and isinstance(k, expression.ColumnElement):
@@ -1490,18 +1782,22 @@ class BulkUpdate(BulkUD):
values.append((attr.key, v))
else:
raise sa_exc.InvalidRequestError(
- "Invalid expression type: %r" % k)
+ "Invalid expression type: %r" % k
+ )
return values
def _do_exec(self):
values = self._resolved_values
- if not self.update_kwargs.get('preserve_parameter_order', False):
+ if not self.update_kwargs.get("preserve_parameter_order", False):
values = dict(values)
- update_stmt = sql.update(self.primary_table,
- self.context.whereclause, values,
- **self.update_kwargs)
+ update_stmt = sql.update(
+ self.primary_table,
+ self.context.whereclause,
+ values,
+ **self.update_kwargs
+ )
self._execute_stmt(update_stmt)
@@ -1518,15 +1814,18 @@ class BulkDelete(BulkUD):
@classmethod
def factory(cls, query, synchronize_session):
- return BulkUD._factory({
- "evaluate": BulkDeleteEvaluate,
- "fetch": BulkDeleteFetch,
- False: BulkDelete
- }, synchronize_session, query)
+ return BulkUD._factory(
+ {
+ "evaluate": BulkDeleteEvaluate,
+ "fetch": BulkDeleteFetch,
+ False: BulkDelete,
+ },
+ synchronize_session,
+ query,
+ )
def _do_exec(self):
- delete_stmt = sql.delete(self.primary_table,
- self.context.whereclause)
+ delete_stmt = sql.delete(self.primary_table, self.context.whereclause)
self._execute_stmt(delete_stmt)
@@ -1544,32 +1843,33 @@ class BulkUpdateEvaluate(BulkEvaluate, BulkUpdate):
values = self._resolved_values_keys_as_propnames
for key, value in values:
self.value_evaluators[key] = evaluator_compiler.process(
- expression._literal_as_binds(value))
+ expression._literal_as_binds(value)
+ )
def _do_post_synchronize(self):
session = self.query.session
states = set()
evaluated_keys = list(self.value_evaluators.keys())
for obj in self.matched_objects:
- state, dict_ = attributes.instance_state(obj),\
- attributes.instance_dict(obj)
+ state, dict_ = (
+ attributes.instance_state(obj),
+ attributes.instance_dict(obj),
+ )
# only evaluate unmodified attributes
- to_evaluate = state.unmodified.intersection(
- evaluated_keys)
+ to_evaluate = state.unmodified.intersection(evaluated_keys)
for key in to_evaluate:
dict_[key] = self.value_evaluators[key](obj)
- state.manager.dispatch.refresh(
- state, None, to_evaluate)
+ state.manager.dispatch.refresh(state, None, to_evaluate)
state._commit(dict_, list(to_evaluate))
# expire attributes with pending changes
# (there was no autoflush, so they are overwritten)
- state._expire_attributes(dict_,
- set(evaluated_keys).
- difference(to_evaluate))
+ state._expire_attributes(
+ dict_, set(evaluated_keys).difference(to_evaluate)
+ )
states.add(state)
session._register_altered(states)
@@ -1580,8 +1880,8 @@ class BulkDeleteEvaluate(BulkEvaluate, BulkDelete):
def _do_post_synchronize(self):
self.query.session._remove_newly_deleted(
- [attributes.instance_state(obj)
- for obj in self.matched_objects])
+ [attributes.instance_state(obj) for obj in self.matched_objects]
+ )
class BulkUpdateFetch(BulkFetch, BulkUpdate):
@@ -1592,15 +1892,18 @@ class BulkUpdateFetch(BulkFetch, BulkUpdate):
session = self.query.session
target_mapper = self.query._mapper_zero()
- states = set([
- attributes.instance_state(session.identity_map[identity_key])
- for identity_key in [
- target_mapper.identity_key_from_primary_key(
- list(primary_key))
- for primary_key in self.matched_rows
+ states = set(
+ [
+ attributes.instance_state(session.identity_map[identity_key])
+ for identity_key in [
+ target_mapper.identity_key_from_primary_key(
+ list(primary_key)
+ )
+ for primary_key in self.matched_rows
+ ]
+ if identity_key in session.identity_map
]
- if identity_key in session.identity_map
- ])
+ )
values = self._resolved_values_keys_as_propnames
attrib = set(k for k, v in values)
@@ -1622,10 +1925,13 @@ class BulkDeleteFetch(BulkFetch, BulkDelete):
# TODO: inline this and call remove_newly_deleted
# once
identity_key = target_mapper.identity_key_from_primary_key(
- list(primary_key))
+ list(primary_key)
+ )
if identity_key in session.identity_map:
session._remove_newly_deleted(
- [attributes.instance_state(
- session.identity_map[identity_key]
- )]
+ [
+ attributes.instance_state(
+ session.identity_map[identity_key]
+ )
+ ]
)
diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py
index ca47fe7ea..a39cd8703 100644
--- a/lib/sqlalchemy/orm/properties.py
+++ b/lib/sqlalchemy/orm/properties.py
@@ -20,7 +20,7 @@ from .util import _orm_full_deannotate
from .interfaces import PropComparator, StrategizedProperty
-__all__ = ['ColumnProperty']
+__all__ = ["ColumnProperty"]
@log.class_logger
@@ -31,14 +31,27 @@ class ColumnProperty(StrategizedProperty):
"""
- strategy_wildcard_key = 'column'
+ strategy_wildcard_key = "column"
__slots__ = (
- '_orig_columns', 'columns', 'group', 'deferred',
- 'instrument', 'comparator_factory', 'descriptor', 'extension',
- 'active_history', 'expire_on_flush', 'info', 'doc',
- 'strategy_key', '_creation_order', '_is_polymorphic_discriminator',
- '_mapped_by_synonym', '_deferred_column_loader')
+ "_orig_columns",
+ "columns",
+ "group",
+ "deferred",
+ "instrument",
+ "comparator_factory",
+ "descriptor",
+ "extension",
+ "active_history",
+ "expire_on_flush",
+ "info",
+ "doc",
+ "strategy_key",
+ "_creation_order",
+ "_is_polymorphic_discriminator",
+ "_mapped_by_synonym",
+ "_deferred_column_loader",
+ )
def __init__(self, *columns, **kwargs):
r"""Provide a column-level property for use with a Mapper.
@@ -117,26 +130,28 @@ class ColumnProperty(StrategizedProperty):
"""
super(ColumnProperty, self).__init__()
self._orig_columns = [expression._labeled(c) for c in columns]
- self.columns = [expression._labeled(_orm_full_deannotate(c))
- for c in columns]
- self.group = kwargs.pop('group', None)
- self.deferred = kwargs.pop('deferred', False)
- self.instrument = kwargs.pop('_instrument', True)
- self.comparator_factory = kwargs.pop('comparator_factory',
- self.__class__.Comparator)
- self.descriptor = kwargs.pop('descriptor', None)
- self.extension = kwargs.pop('extension', None)
- self.active_history = kwargs.pop('active_history', False)
- self.expire_on_flush = kwargs.pop('expire_on_flush', True)
-
- if 'info' in kwargs:
- self.info = kwargs.pop('info')
-
- if 'doc' in kwargs:
- self.doc = kwargs.pop('doc')
+ self.columns = [
+ expression._labeled(_orm_full_deannotate(c)) for c in columns
+ ]
+ self.group = kwargs.pop("group", None)
+ self.deferred = kwargs.pop("deferred", False)
+ self.instrument = kwargs.pop("_instrument", True)
+ self.comparator_factory = kwargs.pop(
+ "comparator_factory", self.__class__.Comparator
+ )
+ self.descriptor = kwargs.pop("descriptor", None)
+ self.extension = kwargs.pop("extension", None)
+ self.active_history = kwargs.pop("active_history", False)
+ self.expire_on_flush = kwargs.pop("expire_on_flush", True)
+
+ if "info" in kwargs:
+ self.info = kwargs.pop("info")
+
+ if "doc" in kwargs:
+ self.doc = kwargs.pop("doc")
else:
for col in reversed(self.columns):
- doc = getattr(col, 'doc', None)
+ doc = getattr(col, "doc", None)
if doc is not None:
self.doc = doc
break
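As a reminder of what this constructor fronts publicly, a minimal sketch using :func:`.column_property`, assuming the usual declarative imports (``Base``, ``Column``, ``Integer``, ``String`` here are illustrative)::

    from sqlalchemy.orm import column_property

    class User(Base):
        __tablename__ = "user"
        id = Column(Integer, primary_key=True)
        firstname = Column(String(50))
        lastname = Column(String(50))
        # the keywords popped above (group, deferred, doc, ...) are accepted here
        fullname = column_property(firstname + " " + lastname, doc="full name")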
@@ -145,22 +160,24 @@ class ColumnProperty(StrategizedProperty):
if kwargs:
raise TypeError(
- "%s received unexpected keyword argument(s): %s" % (
- self.__class__.__name__,
- ', '.join(sorted(kwargs.keys()))))
+ "%s received unexpected keyword argument(s): %s"
+ % (self.__class__.__name__, ", ".join(sorted(kwargs.keys())))
+ )
util.set_creation_order(self)
self.strategy_key = (
("deferred", self.deferred),
- ("instrument", self.instrument)
+ ("instrument", self.instrument),
)
@util.dependencies("sqlalchemy.orm.state", "sqlalchemy.orm.strategies")
def _memoized_attr__deferred_column_loader(self, state, strategies):
return state.InstanceState._instance_level_callable_processor(
self.parent.class_manager,
- strategies.LoadDeferredColumns(self.key), self.key)
+ strategies.LoadDeferredColumns(self.key),
+ self.key,
+ )
def __clause_element__(self):
"""Allow the ColumnProperty to work in expression before it is turned
@@ -185,34 +202,50 @@ class ColumnProperty(StrategizedProperty):
self.key,
comparator=self.comparator_factory(self, mapper),
parententity=mapper,
- doc=self.doc
+ doc=self.doc,
)
def do_init(self):
super(ColumnProperty, self).do_init()
- if len(self.columns) > 1 and \
- set(self.parent.primary_key).issuperset(self.columns):
+ if len(self.columns) > 1 and set(self.parent.primary_key).issuperset(
+ self.columns
+ ):
util.warn(
- ("On mapper %s, primary key column '%s' is being combined "
- "with distinct primary key column '%s' in attribute '%s'. "
- "Use explicit properties to give each column its own mapped "
- "attribute name.") % (self.parent, self.columns[1],
- self.columns[0], self.key))
+ (
+ "On mapper %s, primary key column '%s' is being combined "
+ "with distinct primary key column '%s' in attribute '%s'. "
+ "Use explicit properties to give each column its own mapped "
+ "attribute name."
+ )
+ % (self.parent, self.columns[1], self.columns[0], self.key)
+ )
def copy(self):
return ColumnProperty(
deferred=self.deferred,
group=self.group,
active_history=self.active_history,
- *self.columns)
+ *self.columns
+ )
- def _getcommitted(self, state, dict_, column,
- passive=attributes.PASSIVE_OFF):
- return state.get_impl(self.key).\
- get_committed_value(state, dict_, passive=passive)
+ def _getcommitted(
+ self, state, dict_, column, passive=attributes.PASSIVE_OFF
+ ):
+ return state.get_impl(self.key).get_committed_value(
+ state, dict_, passive=passive
+ )
- def merge(self, session, source_state, source_dict, dest_state,
- dest_dict, load, _recursive, _resolve_conflict_map):
+ def merge(
+ self,
+ session,
+ source_state,
+ source_dict,
+ dest_state,
+ dest_dict,
+ load,
+ _recursive,
+ _resolve_conflict_map,
+ ):
if not self.instrument:
return
elif self.key in source_dict:
@@ -225,7 +258,8 @@ class ColumnProperty(StrategizedProperty):
impl.set(dest_state, dest_dict, value, None)
elif dest_state.has_identity and self.key not in dest_dict:
dest_state._expire_attributes(
- dest_dict, [self.key], no_loader=True)
+ dest_dict, [self.key], no_loader=True
+ )
class Comparator(util.MemoizedSlots, PropComparator):
"""Produce boolean, comparison, and other operators for
@@ -246,7 +280,7 @@ class ColumnProperty(StrategizedProperty):
"""
- __slots__ = '__clause_element__', 'info'
+ __slots__ = "__clause_element__", "info"
def _memoized_method___clause_element__(self):
if self.adapter:
@@ -254,9 +288,12 @@ class ColumnProperty(StrategizedProperty):
else:
# no adapter, so we aren't aliased
# assert self._parententity is self._parentmapper
- return self.prop.columns[0]._annotate({
- "parententity": self._parententity,
- "parentmapper": self._parententity})
+ return self.prop.columns[0]._annotate(
+ {
+ "parententity": self._parententity,
+ "parentmapper": self._parententity,
+ }
+ )
def _memoized_attr_info(self):
ce = self.__clause_element__()
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py
index febf627b4..4a55a3247 100644
--- a/lib/sqlalchemy/orm/query.py
+++ b/lib/sqlalchemy/orm/query.py
@@ -22,26 +22,37 @@ database to return iterable result sets.
from itertools import chain
from . import (
- attributes, interfaces, object_mapper, persistence,
- exc as orm_exc, loading
+ attributes,
+ interfaces,
+ object_mapper,
+ persistence,
+ exc as orm_exc,
+ loading,
+)
+from .base import (
+ _entity_descriptor,
+ _is_aliased_class,
+ _is_mapped_class,
+ _orm_columns,
+ _generative,
+ InspectionAttr,
)
-from .base import _entity_descriptor, _is_aliased_class, \
- _is_mapped_class, _orm_columns, _generative, InspectionAttr
from .path_registry import PathRegistry
from .util import (
- AliasedClass, ORMAdapter, join as orm_join, with_parent, aliased,
- _entity_corresponds_to
+ AliasedClass,
+ ORMAdapter,
+ join as orm_join,
+ with_parent,
+ aliased,
+ _entity_corresponds_to,
)
from .. import sql, util, log, exc as sa_exc, inspect, inspection
from ..sql.expression import _interpret_as_from
-from ..sql import (
- util as sql_util,
- expression, visitors
-)
+from ..sql import util as sql_util, expression, visitors
from ..sql.base import ColumnCollection
from . import properties
-__all__ = ['Query', 'QueryContext', 'aliased']
+__all__ = ["Query", "QueryContext", "aliased"]
_path_registry = PathRegistry.root
@@ -192,16 +203,20 @@ class Query(object):
for entity in ent.entities:
if entity not in d:
ext_info = inspect(entity)
- if not ext_info.is_aliased_class and \
- ext_info.mapper.with_polymorphic:
- if ext_info.mapper.mapped_table not in \
- self._polymorphic_adapters:
+ if (
+ not ext_info.is_aliased_class
+ and ext_info.mapper.with_polymorphic
+ ):
+ if (
+ ext_info.mapper.mapped_table
+ not in self._polymorphic_adapters
+ ):
self._mapper_loads_polymorphically_with(
ext_info.mapper,
sql_util.ColumnAdapter(
ext_info.selectable,
- ext_info.mapper._equivalent_columns
- )
+ ext_info.mapper._equivalent_columns,
+ ),
)
aliased_adapter = None
elif ext_info.is_aliased_class:
@@ -209,10 +224,7 @@ class Query(object):
else:
aliased_adapter = None
- d[entity] = (
- ext_info,
- aliased_adapter
- )
+ d[entity] = (ext_info, aliased_adapter)
ent.setup_entity(*d[entity])
def _mapper_loads_polymorphically_with(self, mapper, adapter):
@@ -227,18 +239,21 @@ class Query(object):
for from_obj in obj:
info = inspect(from_obj)
- if hasattr(info, 'mapper') and \
- (info.is_mapper or info.is_aliased_class):
+ if hasattr(info, "mapper") and (
+ info.is_mapper or info.is_aliased_class
+ ):
self._select_from_entity = info
if set_base_alias and not info.is_aliased_class:
raise sa_exc.ArgumentError(
"A selectable (FromClause) instance is "
- "expected when the base alias is being set.")
+ "expected when the base alias is being set."
+ )
fa.append(info.selectable)
elif not info.is_selectable:
raise sa_exc.ArgumentError(
"argument is not a mapped class, mapper, "
- "aliased(), or FromClause instance.")
+ "aliased(), or FromClause instance."
+ )
else:
if isinstance(from_obj, expression.SelectBase):
from_obj = from_obj.alias()
@@ -248,16 +263,21 @@ class Query(object):
self._from_obj = tuple(fa)
- if set_base_alias and \
- len(self._from_obj) == 1 and \
- isinstance(select_from_alias, expression.Alias):
+ if (
+ set_base_alias
+ and len(self._from_obj) == 1
+ and isinstance(select_from_alias, expression.Alias)
+ ):
equivs = self.__all_equivs()
self._from_obj_alias = sql_util.ColumnAdapter(
- self._from_obj[0], equivs)
- elif set_base_alias and \
- len(self._from_obj) == 1 and \
- hasattr(info, "mapper") and \
- info.is_aliased_class:
+ self._from_obj[0], equivs
+ )
+ elif (
+ set_base_alias
+ and len(self._from_obj) == 1
+ and hasattr(info, "mapper")
+ and info.is_aliased_class
+ ):
self._from_obj_alias = info._adapter
def _reset_polymorphic_adapter(self, mapper):
@@ -268,14 +288,14 @@ class Query(object):
def _adapt_polymorphic_element(self, element):
if "parententity" in element._annotations:
- search = element._annotations['parententity']
+ search = element._annotations["parententity"]
alias = self._polymorphic_adapters.get(search, None)
if alias:
return alias.adapt_clause(element)
if isinstance(element, expression.FromClause):
search = element
- elif hasattr(element, 'table'):
+ elif hasattr(element, "table"):
search = element.table
else:
return None
@@ -287,8 +307,8 @@ class Query(object):
def _adapt_col_list(self, cols):
return [
self._adapt_clause(
- expression._literal_as_label_reference(o),
- True, True)
+ expression._literal_as_label_reference(o), True, True
+ )
for o in cols
]
@@ -312,11 +332,7 @@ class Query(object):
if as_filter and self._filter_aliases:
for fa in self._filter_aliases.visitor_iterator:
- adapters.append(
- (
- orm_only, fa.replace
- )
- )
+ adapters.append((orm_only, fa.replace))
if self._from_obj_alias:
# for the "from obj" alias, apply extra rule to the
@@ -326,16 +342,12 @@ class Query(object):
adapters.append(
(
orm_only if self._orm_only_from_obj_alias else False,
- self._from_obj_alias.replace
+ self._from_obj_alias.replace,
)
)
if self._polymorphic_adapters:
- adapters.append(
- (
- orm_only, self._adapt_polymorphic_element
- )
- )
+ adapters.append((orm_only, self._adapt_polymorphic_element))
if not adapters:
return clause
@@ -344,19 +356,17 @@ class Query(object):
for _orm_only, adapter in adapters:
# if 'orm only', look for ORM annotations
# in the element before adapting.
- if not _orm_only or \
- '_orm_adapt' in elem._annotations or \
- "parententity" in elem._annotations:
+ if (
+ not _orm_only
+ or "_orm_adapt" in elem._annotations
+ or "parententity" in elem._annotations
+ ):
e = adapter(elem)
if e is not None:
return e
- return visitors.replacement_traverse(
- clause,
- {},
- replace
- )
+ return visitors.replacement_traverse(clause, {}, replace)
def _query_entity_zero(self):
"""Return the first QueryEntity."""
@@ -371,9 +381,11 @@ class Query(object):
with the first QueryEntity, or alternatively the 'select from'
entity if specified."""
- return self._select_from_entity \
- if self._select_from_entity is not None \
+ return (
+ self._select_from_entity
+ if self._select_from_entity is not None
else self._query_entity_zero().entity_zero
+ )
@property
def _mapper_entities(self):
@@ -382,10 +394,7 @@ class Query(object):
yield ent
def _joinpoint_zero(self):
- return self._joinpoint.get(
- '_joinpoint_entity',
- self._entity_zero()
- )
+ return self._joinpoint.get("_joinpoint_entity", self._entity_zero())
def _bind_mapper(self):
ezero = self._entity_zero()
@@ -400,14 +409,15 @@ class Query(object):
if self._entities != [self._primary_entity]:
raise sa_exc.InvalidRequestError(
"%s() can only be used against "
- "a single mapped class." % methname)
+ "a single mapped class." % methname
+ )
return self._primary_entity.entity_zero
def _only_entity_zero(self, rationale=None):
if len(self._entities) > 1:
raise sa_exc.InvalidRequestError(
- rationale or
- "This operation requires a Query "
+ rationale
+ or "This operation requires a Query "
"against a single mapper."
)
return self._entity_zero()
@@ -420,7 +430,8 @@ class Query(object):
def _get_condition(self):
return self._no_criterion_condition(
- "get", order_by=False, distinct=False)
+ "get", order_by=False, distinct=False
+ )
def _get_existing_condition(self):
self._no_criterion_assertion("get", order_by=False, distinct=False)
@@ -428,14 +439,20 @@ class Query(object):
def _no_criterion_assertion(self, meth, order_by=True, distinct=True):
if not self._enable_assertions:
return
- if self._criterion is not None or \
- self._statement is not None or self._from_obj or \
- self._limit is not None or self._offset is not None or \
- self._group_by or (order_by and self._order_by) or \
- (distinct and self._distinct):
+ if (
+ self._criterion is not None
+ or self._statement is not None
+ or self._from_obj
+ or self._limit is not None
+ or self._offset is not None
+ or self._group_by
+ or (order_by and self._order_by)
+ or (distinct and self._distinct)
+ ):
raise sa_exc.InvalidRequestError(
"Query.%s() being called on a "
- "Query with existing criterion. " % meth)
+ "Query with existing criterion. " % meth
+ )
def _no_criterion_condition(self, meth, order_by=True, distinct=True):
self._no_criterion_assertion(meth, order_by, distinct)
@@ -450,7 +467,8 @@ class Query(object):
if self._order_by:
raise sa_exc.InvalidRequestError(
"Query.%s() being called on a "
- "Query with existing criterion. " % meth)
+ "Query with existing criterion. " % meth
+ )
self._no_criterion_condition(meth)
def _no_statement_condition(self, meth):
@@ -458,8 +476,12 @@ class Query(object):
return
if self._statement is not None:
raise sa_exc.InvalidRequestError(
- ("Query.%s() being called on a Query with an existing full "
- "statement - can't apply criterion.") % meth)
+ (
+ "Query.%s() being called on a Query with an existing full "
+ "statement - can't apply criterion."
+ )
+ % meth
+ )
def _no_limit_offset(self, meth):
if not self._enable_assertions:
@@ -470,15 +492,17 @@ class Query(object):
"or OFFSET applied. To modify the row-limited results of a "
" Query, call from_self() first. "
"Otherwise, call %s() before limit() or offset() "
- "are applied."
- % (meth, meth)
+ "are applied." % (meth, meth)
)
- def _get_options(self, populate_existing=None,
- version_check=None,
- only_load_props=None,
- refresh_state=None,
- identity_token=None):
+ def _get_options(
+ self,
+ populate_existing=None,
+ version_check=None,
+ only_load_props=None,
+ refresh_state=None,
+ identity_token=None,
+ ):
if populate_existing:
self._populate_existing = populate_existing
if version_check:
@@ -507,8 +531,7 @@ class Query(object):
"""
- stmt = self._compile_context(labels=self._with_labels).\
- statement
+ stmt = self._compile_context(labels=self._with_labels).statement
if self._params:
stmt = stmt.params(self._params)
@@ -602,8 +625,9 @@ class Query(object):
:meth:`.HasCTE.cte`
"""
- return self.enable_eagerloads(False).\
- statement.cte(name=name, recursive=recursive)
+ return self.enable_eagerloads(False).statement.cte(
+ name=name, recursive=recursive
+ )
def label(self, name):
"""Return the full SELECT statement represented by this
@@ -678,7 +702,8 @@ class Query(object):
"compatible with %s eager loading. Please "
"specify lazyload('*') or query.enable_eagerloads(False) in "
"order to "
- "proceed with query.yield_per()." % message)
+ "proceed with query.yield_per()." % message
+ )
@_generative()
def with_labels(self):
@@ -752,10 +777,9 @@ class Query(object):
self._current_path = path
@_generative(_no_clauseelement_condition)
- def with_polymorphic(self,
- cls_or_mappers,
- selectable=None,
- polymorphic_on=None):
+ def with_polymorphic(
+ self, cls_or_mappers, selectable=None, polymorphic_on=None
+ ):
"""Load columns for inheriting classes.
:meth:`.Query.with_polymorphic` applies transformations
@@ -783,13 +807,16 @@ class Query(object):
if not self._primary_entity:
raise sa_exc.InvalidRequestError(
- "No primary mapper set up for this Query.")
+ "No primary mapper set up for this Query."
+ )
entity = self._entities[0]._clone()
self._entities = [entity] + self._entities[1:]
- entity.set_with_polymorphic(self,
- cls_or_mappers,
- selectable=selectable,
- polymorphic_on=polymorphic_on)
+ entity.set_with_polymorphic(
+ self,
+ cls_or_mappers,
+ selectable=selectable,
+ polymorphic_on=polymorphic_on,
+ )
@_generative()
def yield_per(self, count):
@@ -858,8 +885,8 @@ class Query(object):
"""
self._yield_per = count
self._execution_options = self._execution_options.union(
- {"stream_results": True,
- "max_row_buffer": count})
+ {"stream_results": True, "max_row_buffer": count}
+ )
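The two execution options set above are what enable streaming — a usage sketch, assuming a hypothetical ``User`` mapping and a DBAPI that supports server-side cursors::

    for user in session.query(User).yield_per(100):
        ...  # rows are fetched and hydrated in batches of 100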
def get(self, ident):
"""Return an instance based on the given primary key identifier,
@@ -918,12 +945,16 @@ class Query(object):
:return: The object instance, or ``None``.
"""
- return self._get_impl(
- ident, loading.load_on_pk_identity)
-
- def _identity_lookup(self, mapper, primary_key_identity,
- identity_token=None, passive=attributes.PASSIVE_OFF,
- lazy_loaded_from=None):
+ return self._get_impl(ident, loading.load_on_pk_identity)
+
+ def _identity_lookup(
+ self,
+ mapper,
+ primary_key_identity,
+ identity_token=None,
+ passive=attributes.PASSIVE_OFF,
+ lazy_loaded_from=None,
+ ):
"""Locate an object in the identity map.
Given a primary key identity, constructs an identity key and then
@@ -966,14 +997,13 @@ class Query(object):
"""
key = mapper.identity_key_from_primary_key(
- primary_key_identity, identity_token=identity_token)
- return loading.get_from_identity(
- self.session, key, passive)
+ primary_key_identity, identity_token=identity_token
+ )
+ return loading.get_from_identity(self.session, key, passive)
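This identity-map lookup is what lets :meth:`.Query.get` return without emitting SQL when the object is already present. A sketch (``User`` and ``Order`` are hypothetical mappings)::

    user = session.query(User).get(5)           # scalar primary key
    order = session.query(Order).get((7, 11))   # composite key, in column order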
- def _get_impl(
- self, primary_key_identity, db_load_fn, identity_token=None):
+    def _get_impl(
+        self, primary_key_identity, db_load_fn, identity_token=None
+    ):
# convert composite types to individual args
- if hasattr(primary_key_identity, '__composite_values__'):
+ if hasattr(primary_key_identity, "__composite_values__"):
primary_key_identity = primary_key_identity.__composite_values__()
primary_key_identity = util.to_list(primary_key_identity)
@@ -983,16 +1013,19 @@ class Query(object):
if len(primary_key_identity) != len(mapper.primary_key):
raise sa_exc.InvalidRequestError(
"Incorrect number of values in identifier to formulate "
- "primary key for query.get(); primary key columns are %s" %
- ','.join("'%s'" % c for c in mapper.primary_key))
+ "primary key for query.get(); primary key columns are %s"
+ % ",".join("'%s'" % c for c in mapper.primary_key)
+ )
- if not self._populate_existing and \
- not mapper.always_refresh and \
- self._for_update_arg is None:
+ if (
+ not self._populate_existing
+ and not mapper.always_refresh
+ and self._for_update_arg is None
+ ):
instance = self._identity_lookup(
- mapper, primary_key_identity,
- identity_token=identity_token)
+ mapper, primary_key_identity, identity_token=identity_token
+ )
if instance is not None:
self._get_existing_condition()
@@ -1106,17 +1139,20 @@ class Query(object):
mapper = object_mapper(instance)
for prop in mapper.iterate_properties:
- if isinstance(prop, properties.RelationshipProperty) and \
- prop.mapper is entity_zero.mapper:
+ if (
+ isinstance(prop, properties.RelationshipProperty)
+ and prop.mapper is entity_zero.mapper
+ ):
property = prop
break
else:
raise sa_exc.InvalidRequestError(
"Could not locate a property which relates instances "
- "of class '%s' to instances of class '%s'" %
- (
+ "of class '%s' to instances of class '%s'"
+ % (
entity_zero.mapper.class_.__name__,
- instance.__class__.__name__)
+ instance.__class__.__name__,
+ )
)
return self.filter(with_parent(instance, property, entity_zero.entity))
@@ -1323,8 +1359,11 @@ class Query(object):
those being selected.
"""
- fromclause = self.with_labels().enable_eagerloads(False).\
- statement.correlate(None)
+ fromclause = (
+ self.with_labels()
+ .enable_eagerloads(False)
+ .statement.correlate(None)
+ )
q = self._from_selectable(fromclause)
q._enable_single_crit = False
q._select_from_entity = self._entity_zero()
@@ -1339,12 +1378,18 @@ class Query(object):
@_generative()
def _from_selectable(self, fromclause):
for attr in (
- '_statement', '_criterion',
- '_order_by', '_group_by',
- '_limit', '_offset',
- '_joinpath', '_joinpoint',
- '_distinct', '_having',
- '_prefixes', '_suffixes'
+ "_statement",
+ "_criterion",
+ "_order_by",
+ "_group_by",
+ "_limit",
+ "_offset",
+ "_joinpath",
+ "_joinpoint",
+ "_distinct",
+ "_having",
+ "_prefixes",
+ "_suffixes",
):
self.__dict__.pop(attr, None)
self._set_select_from([fromclause], True)
@@ -1369,6 +1414,7 @@ class Query(object):
if not q._yield_per:
q._yield_per = 10
return iter(q)
+
_values = values
def value(self, column):
@@ -1420,10 +1466,11 @@ class Query(object):
# given arg is a FROM clause
self._set_entity_selectables(self._entities[l:])
- @util.pending_deprecation("0.7",
- ":meth:`.add_column` is superseded "
- "by :meth:`.add_columns`",
- False)
+ @util.pending_deprecation(
+ "0.7",
+ ":meth:`.add_column` is superseded " "by :meth:`.add_columns`",
+ False,
+ )
def add_column(self, column):
"""Add a column expression to the list of result columns to be
returned.
@@ -1454,8 +1501,8 @@ class Query(object):
# most MapperOptions write to the '_attributes' dictionary,
# so copy that as well
self._attributes = self._attributes.copy()
- if '_unbound_load_dedupes' not in self._attributes:
- self._attributes['_unbound_load_dedupes'] = set()
+ if "_unbound_load_dedupes" not in self._attributes:
+ self._attributes["_unbound_load_dedupes"] = set()
opts = tuple(util.flatten_iterator(args))
self._with_options = self._with_options + opts
if conditional:
@@ -1487,7 +1534,7 @@ class Query(object):
return fn(self)
@_generative()
- def with_hint(self, selectable, text, dialect_name='*'):
+ def with_hint(self, selectable, text, dialect_name="*"):
"""Add an indexing or other executional context
hint for the given entity or selectable to
this :class:`.Query`.
@@ -1508,7 +1555,7 @@ class Query(object):
self._with_hints += ((selectable, text, dialect_name),)
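A usage sketch for the hint tuple recorded above; the hint text is passed through verbatim and is dialect-specific (``User`` is a hypothetical mapping, the index name illustrative)::

    session.query(User).with_hint(
        User, "USE INDEX (ix_user_name)", dialect_name="mysql"
    )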
- def with_statement_hint(self, text, dialect_name='*'):
+ def with_statement_hint(self, text, dialect_name="*"):
"""add a statement hint to this :class:`.Select`.
This method is similar to :meth:`.Select.with_hint` except that
@@ -1570,8 +1617,14 @@ class Query(object):
self._for_update_arg = LockmodeArg.parse_legacy_query(mode)
@_generative()
- def with_for_update(self, read=False, nowait=False, of=None,
- skip_locked=False, key_share=False):
+ def with_for_update(
+ self,
+ read=False,
+ nowait=False,
+ of=None,
+ skip_locked=False,
+ key_share=False,
+ ):
"""return a new :class:`.Query` with the specified options for the
``FOR UPDATE`` clause.
@@ -1599,9 +1652,13 @@ class Query(object):
full argument and behavioral description.
"""
- self._for_update_arg = LockmodeArg(read=read, nowait=nowait, of=of,
- skip_locked=skip_locked,
- key_share=key_share)
+ self._for_update_arg = LockmodeArg(
+ read=read,
+ nowait=nowait,
+ of=of,
+ skip_locked=skip_locked,
+ key_share=key_share,
+ )
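Each keyword above maps onto a field of the :class:`.ForUpdateArg` construct. A sketch, assuming a hypothetical ``User`` mapping::

    session.query(User).filter(User.id == 5).with_for_update(
        nowait=True, of=User
    )
    # roughly: SELECT ... FROM "user" ... FOR UPDATE OF "user" NOWAIT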
@_generative()
def params(self, *args, **kwargs):
@@ -1619,7 +1676,8 @@ class Query(object):
elif len(args) > 0:
raise sa_exc.ArgumentError(
"params() takes zero or one positional argument, "
- "which is a dictionary.")
+ "which is a dictionary."
+ )
self._params = self._params.copy()
self._params.update(kwargs)
@@ -1683,8 +1741,10 @@ class Query(object):
"""
- clauses = [_entity_descriptor(self._joinpoint_zero(), key) == value
- for key, value in kwargs.items()]
+ clauses = [
+ _entity_descriptor(self._joinpoint_zero(), key) == value
+ for key, value in kwargs.items()
+ ]
return self.filter(sql.and_(*clauses))
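The comprehension above is the whole of :meth:`.filter_by`: keyword names are resolved against the current joinpoint entity and turned into equality criteria. A sketch (``User`` hypothetical)::

    session.query(User).filter_by(name="ed", fullname="Ed Jones")
    # equivalent to:
    session.query(User).filter(User.name == "ed", User.fullname == "Ed Jones")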
@_generative(_no_statement_condition, _no_limit_offset)
@@ -1704,7 +1764,7 @@ class Query(object):
if len(criterion) == 1:
if criterion[0] is False:
- if '_order_by' in self.__dict__:
+ if "_order_by" in self.__dict__:
self._order_by = False
return
if criterion[0] is None:
@@ -1765,11 +1825,13 @@ class Query(object):
criterion = expression._expression_literal_as_text(criterion)
- if criterion is not None and \
- not isinstance(criterion, sql.ClauseElement):
+ if criterion is not None and not isinstance(
+ criterion, sql.ClauseElement
+ ):
raise sa_exc.ArgumentError(
"having() argument must be of type "
- "sqlalchemy.sql.ClauseElement or string")
+ "sqlalchemy.sql.ClauseElement or string"
+ )
criterion = self._adapt_clause(criterion, True, True)
@@ -2122,17 +2184,23 @@ class Query(object):
SQLAlchemy versions was the primary ORM-level joining interface.
"""
- aliased, from_joinpoint, isouter, full = kwargs.pop('aliased', False),\
- kwargs.pop('from_joinpoint', False),\
- kwargs.pop('isouter', False),\
- kwargs.pop('full', False)
+ aliased, from_joinpoint, isouter, full = (
+ kwargs.pop("aliased", False),
+ kwargs.pop("from_joinpoint", False),
+ kwargs.pop("isouter", False),
+ kwargs.pop("full", False),
+ )
if kwargs:
- raise TypeError("unknown arguments: %s" %
- ', '.join(sorted(kwargs)))
- return self._join(props,
- outerjoin=isouter, full=full,
- create_aliases=aliased,
- from_joinpoint=from_joinpoint)
+ raise TypeError(
+ "unknown arguments: %s" % ", ".join(sorted(kwargs))
+ )
+ return self._join(
+ props,
+ outerjoin=isouter,
+ full=full,
+ create_aliases=aliased,
+ from_joinpoint=from_joinpoint,
+ )
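A sketch of the keyword arguments unpacked above (``User`` and ``Address`` are hypothetical mappings)::

    session.query(User).join(Address, isouter=True)      # LEFT OUTER JOIN
    session.query(User).join("addresses", aliased=True)  # auto-aliased target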
def outerjoin(self, *props, **kwargs):
"""Create a left outer join against this ``Query`` object's criterion
@@ -2141,25 +2209,32 @@ class Query(object):
Usage is the same as the ``join()`` method.
"""
- aliased, from_joinpoint, full = kwargs.pop('aliased', False), \
- kwargs.pop('from_joinpoint', False), \
- kwargs.pop('full', False)
+ aliased, from_joinpoint, full = (
+ kwargs.pop("aliased", False),
+ kwargs.pop("from_joinpoint", False),
+ kwargs.pop("full", False),
+ )
if kwargs:
- raise TypeError("unknown arguments: %s" %
- ', '.join(sorted(kwargs)))
- return self._join(props,
- outerjoin=True, full=full, create_aliases=aliased,
- from_joinpoint=from_joinpoint)
+ raise TypeError(
+ "unknown arguments: %s" % ", ".join(sorted(kwargs))
+ )
+ return self._join(
+ props,
+ outerjoin=True,
+ full=full,
+ create_aliases=aliased,
+ from_joinpoint=from_joinpoint,
+ )
def _update_joinpoint(self, jp):
self._joinpoint = jp
# copy backwards to the root of the _joinpath
# dict, so that no existing dict in the path is mutated
- while 'prev' in jp:
- f, prev = jp['prev']
+ while "prev" in jp:
+ f, prev = jp["prev"]
prev = prev.copy()
prev[f] = jp
- jp['prev'] = (f, prev)
+ jp["prev"] = (f, prev)
jp = prev
self._joinpath = jp
@@ -2173,11 +2248,16 @@ class Query(object):
if not from_joinpoint:
self._reset_joinpoint()
- if len(keys) == 2 and \
- isinstance(keys[0], (expression.FromClause,
- type, AliasedClass)) and \
- isinstance(keys[1], (str, expression.ClauseElement,
- interfaces.PropComparator)):
+ if (
+ len(keys) == 2
+ and isinstance(
+ keys[0], (expression.FromClause, type, AliasedClass)
+ )
+ and isinstance(
+ keys[1],
+ (str, expression.ClauseElement, interfaces.PropComparator),
+ )
+ ):
# detect 2-arg form of join and
# convert to a tuple.
keys = (keys,)
@@ -2202,20 +2282,22 @@ class Query(object):
# is a little bit of legacy behavior still at work here
# which means they might be in either order.
if isinstance(
- arg1, (interfaces.PropComparator, util.string_types)):
+ arg1, (interfaces.PropComparator, util.string_types)
+ ):
right, onclause = arg2, arg1
else:
right, onclause = arg1, arg2
if onclause is None:
r_info = inspect(right)
- if not r_info.is_selectable and not hasattr(r_info, 'mapper'):
+ if not r_info.is_selectable and not hasattr(r_info, "mapper"):
raise sa_exc.ArgumentError(
"Expected mapped entity or "
- "selectable/table as join target")
+ "selectable/table as join target"
+ )
if isinstance(onclause, interfaces.PropComparator):
- of_type = getattr(onclause, '_of_type', None)
+ of_type = getattr(onclause, "_of_type", None)
else:
of_type = None
@@ -2234,12 +2316,13 @@ class Query(object):
# to work with the aliased=True flag, which is also something
# that probably shouldn't exist on join() due to its high
# complexity/usefulness ratio
- elif from_joinpoint and \
- isinstance(onclause, interfaces.PropComparator):
+ elif from_joinpoint and isinstance(
+ onclause, interfaces.PropComparator
+ ):
jp0 = self._joinpoint_zero()
info = inspect(jp0)
- if getattr(info, 'mapper', None) is onclause._parententity:
+ if getattr(info, "mapper", None) is onclause._parententity:
onclause = _entity_descriptor(jp0, onclause.key)
if isinstance(onclause, interfaces.PropComparator):
@@ -2256,8 +2339,7 @@ class Query(object):
alias = self._polymorphic_adapters.get(left, None)
# could be None or could be ColumnAdapter also
- if isinstance(alias, ORMAdapter) and \
- alias.mapper.isa(left):
+ if isinstance(alias, ORMAdapter) and alias.mapper.isa(left):
left = alias.aliased_class
onclause = getattr(left, onclause.key)
@@ -2278,14 +2360,15 @@ class Query(object):
# and then mutate the child, which might be
# shared by a different query object.
jp = self._joinpoint[edge].copy()
- jp['prev'] = (edge, self._joinpoint)
+ jp["prev"] = (edge, self._joinpoint)
self._update_joinpoint(jp)
# warn only on the last element of the list
if idx == len(keylist) - 1:
util.warn(
"Pathed join target %s has already "
- "been joined to; skipping" % prop)
+ "been joined to; skipping" % prop
+ )
continue
else:
# no descriptor/property given; we will need to figure out
@@ -2295,13 +2378,12 @@ class Query(object):
# figure out the final "left" and "right" sides and create an
# ORMJoin to add to our _from_obj tuple
self._join_left_to_right(
- left, right, onclause, prop, create_aliases,
- outerjoin, full
+ left, right, onclause, prop, create_aliases, outerjoin, full
)
def _join_left_to_right(
- self, left, right, onclause, prop,
- create_aliases, outerjoin, full):
+ self, left, right, onclause, prop, create_aliases, outerjoin, full
+ ):
"""given raw "left", "right", "onclause" parameters consumed from
a particular key within _join(), add a real ORMJoin object to
our _from_obj list (or augment an existing one)
@@ -2315,15 +2397,17 @@ class Query(object):
# figure out the best "left" side based on our existing froms /
# entities
assert prop is None
- left, replace_from_obj_index, use_entity_index = \
- self._join_determine_implicit_left_side(left, right, onclause)
+            left, replace_from_obj_index, use_entity_index = (
+                self._join_determine_implicit_left_side(left, right, onclause)
+            )
else:
# left is given via a relationship/name. Determine where in our
# "froms" list it should be spliced/appended as well as what
# existing entity it corresponds to.
assert prop is not None
- replace_from_obj_index, use_entity_index = \
- self._join_place_explicit_left_side(left)
+            replace_from_obj_index, use_entity_index = (
+                self._join_place_explicit_left_side(left)
+            )
# this should never happen because we would not have found a place
# to join on
@@ -2333,7 +2417,7 @@ class Query(object):
# a lot of things can be wrong with it. handle all that and
# get back the new effective "right" side
r_info, right, onclause = self._join_check_and_adapt_right_side(
- left, right, onclause, prop, create_aliases,
+ left, right, onclause, prop, create_aliases
)
if replace_from_obj_index is not None:
@@ -2342,11 +2426,18 @@ class Query(object):
left_clause = self._from_obj[replace_from_obj_index]
self._from_obj = (
- self._from_obj[:replace_from_obj_index] +
- (orm_join(
- left_clause, right,
- onclause, isouter=outerjoin, full=full), ) +
- self._from_obj[replace_from_obj_index + 1:])
+ self._from_obj[:replace_from_obj_index]
+ + (
+ orm_join(
+ left_clause,
+ right,
+ onclause,
+ isouter=outerjoin,
+ full=full,
+ ),
+ )
+ + self._from_obj[replace_from_obj_index + 1 :]
+ )
else:
# add a new element to the self._from_obj list
@@ -2358,8 +2449,8 @@ class Query(object):
self._from_obj = self._from_obj + (
orm_join(
- left_clause, right, onclause,
- isouter=outerjoin, full=full),
+ left_clause, right, onclause, isouter=outerjoin, full=full
+ ),
)
def _join_determine_implicit_left_side(self, left, right, onclause):
@@ -2388,8 +2479,8 @@ class Query(object):
# join has to connect to one of those FROMs.
indexes = sql_util.find_left_clause_to_join_from(
- self._from_obj,
- r_info.selectable, onclause)
+ self._from_obj, r_info.selectable, onclause
+ )
if len(indexes) == 1:
replace_from_obj_index = indexes[0]
@@ -2399,12 +2490,13 @@ class Query(object):
"Can't determine which FROM clause to join "
"from, there are multiple FROMS which can "
"join to this entity. Try adding an explicit ON clause "
- "to help resolve the ambiguity.")
+ "to help resolve the ambiguity."
+ )
else:
raise sa_exc.InvalidRequestError(
"Don't know how to join to %s; please use "
"an ON clause to more clearly establish the left "
- "side of this join" % (right, )
+ "side of this join" % (right,)
)
elif self._entities:
@@ -2430,7 +2522,8 @@ class Query(object):
all_clauses = list(potential.keys())
indexes = sql_util.find_left_clause_to_join_from(
- all_clauses, r_info.selectable, onclause)
+ all_clauses, r_info.selectable, onclause
+ )
if len(indexes) == 1:
use_entity_index, left = potential[all_clauses[indexes[0]]]
@@ -2439,18 +2532,20 @@ class Query(object):
"Can't determine which FROM clause to join "
"from, there are multiple FROMS which can "
"join to this entity. Try adding an explicit ON clause "
- "to help resolve the ambiguity.")
+ "to help resolve the ambiguity."
+ )
else:
raise sa_exc.InvalidRequestError(
"Don't know how to join to %s; please use "
"an ON clause to more clearly establish the left "
- "side of this join" % (right, )
+ "side of this join" % (right,)
)
else:
raise sa_exc.InvalidRequestError(
"No entities to join from; please use "
"select_from() to establish the left "
- "entity/selectable of this join")
+ "entity/selectable of this join"
+ )
return left, replace_from_obj_index, use_entity_index
@@ -2484,13 +2579,15 @@ class Query(object):
l_info = inspect(left)
if self._from_obj:
indexes = sql_util.find_left_clause_that_matches_given(
- self._from_obj, l_info.selectable)
+ self._from_obj, l_info.selectable
+ )
if len(indexes) > 1:
raise sa_exc.InvalidRequestError(
"Can't identify which entity in which to assign the "
"left side of this join. Please use a more specific "
- "ON clause.")
+ "ON clause."
+ )
# have an index, means the left side is already present in
# an existing FROM in the self._from_obj tuple
@@ -2504,8 +2601,11 @@ class Query(object):
# self._from_obj tuple. Determine if this left side matches up
# with existing mapper entities, in which case we want to apply the
# aliasing / adaptation rules present on that entity if any
- if replace_from_obj_index is None and \
- self._entities and hasattr(l_info, 'mapper'):
+ if (
+ replace_from_obj_index is None
+ and self._entities
+ and hasattr(l_info, "mapper")
+ ):
for idx, ent in enumerate(self._entities):
# TODO: should we be checking for multiple mapper entities
# matching?
@@ -2516,7 +2616,8 @@ class Query(object):
return replace_from_obj_index, use_entity_index
def _join_check_and_adapt_right_side(
- self, left, right, onclause, prop, create_aliases):
+ self, left, right, onclause, prop, create_aliases
+ ):
"""transform the "right" side of the join as well as the onclause
according to polymorphic mapping translations, aliasing on the query
or on the join, special cases where the right and left side have
@@ -2533,30 +2634,37 @@ class Query(object):
# if the target is a joined inheritance mapping,
# be more liberal about auto-aliasing.
if right_mapper and (
- right_mapper.with_polymorphic or
- isinstance(right_mapper.mapped_table, expression.Join)
+ right_mapper.with_polymorphic
+ or isinstance(right_mapper.mapped_table, expression.Join)
):
for from_obj in self._from_obj or [l_info.selectable]:
if sql_util.selectables_overlap(
- l_info.selectable, from_obj) and \
- sql_util.selectables_overlap(
- from_obj, r_info.selectable):
+ l_info.selectable, from_obj
+ ) and sql_util.selectables_overlap(
+ from_obj, r_info.selectable
+ ):
overlap = True
break
- if (overlap or not create_aliases) and \
- l_info.selectable is r_info.selectable:
+ if (
+ overlap or not create_aliases
+ ) and l_info.selectable is r_info.selectable:
raise sa_exc.InvalidRequestError(
- "Can't join table/selectable '%s' to itself" %
- l_info.selectable)
+ "Can't join table/selectable '%s' to itself"
+ % l_info.selectable
+ )
- right_mapper, right_selectable, right_is_aliased = \
- getattr(r_info, 'mapper', None), \
- r_info.selectable, \
- getattr(r_info, 'is_aliased_class', False)
+ right_mapper, right_selectable, right_is_aliased = (
+ getattr(r_info, "mapper", None),
+ r_info.selectable,
+ getattr(r_info, "is_aliased_class", False),
+ )
- if right_mapper and prop and \
- not right_mapper.common_parent(prop.mapper):
+ if (
+ right_mapper
+ and prop
+ and not right_mapper.common_parent(prop.mapper)
+ ):
raise sa_exc.InvalidRequestError(
"Join target %s does not correspond to "
"the right side of join condition %s" % (right, onclause)
@@ -2564,8 +2672,8 @@ class Query(object):
# _join_entities is used as a hint for single-table inheritance
# purposes at the moment
- if hasattr(r_info, 'mapper'):
- self._join_entities += (r_info, )
+ if hasattr(r_info, "mapper"):
+ self._join_entities += (r_info,)
if not right_mapper and prop:
right_mapper = prop.mapper
@@ -2579,12 +2687,14 @@ class Query(object):
right = self._adapt_clause(right, True, False)
if right_mapper and right is right_selectable:
- if not right_selectable.is_derived_from(
- right_mapper.mapped_table):
+ if not right_selectable.is_derived_from(right_mapper.mapped_table):
raise sa_exc.InvalidRequestError(
- "Selectable '%s' is not derived from '%s'" %
- (right_selectable.description,
- right_mapper.mapped_table.description))
+ "Selectable '%s' is not derived from '%s'"
+ % (
+ right_selectable.description,
+ right_mapper.mapped_table.description,
+ )
+ )
if isinstance(right_selectable, expression.SelectBase):
# TODO: this isn't even covered now!
@@ -2593,16 +2703,20 @@ class Query(object):
right = aliased(right_mapper, right_selectable)
- aliased_entity = right_mapper and \
- not right_is_aliased and \
- (
- right_mapper.with_polymorphic and isinstance(
- right_mapper._with_polymorphic_selectable,
- expression.Alias) or overlap
+ aliased_entity = (
+ right_mapper
+ and not right_is_aliased
+ and (
+ right_mapper.with_polymorphic
+ and isinstance(
+ right_mapper._with_polymorphic_selectable, expression.Alias
+ )
+ or overlap
# test for overlap:
# orm/inheritance/relationships.py
# SelfReferentialM2MTest
)
+ )
if not need_adapter and (create_aliases or aliased_entity):
right = aliased(right, flat=True)
@@ -2614,9 +2728,11 @@ class Query(object):
if need_adapter:
self._filter_aliases = ORMAdapter(
right,
- equivalents=right_mapper and
- right_mapper._equivalent_columns or {},
- chain_to=self._filter_aliases)
+ equivalents=right_mapper
+ and right_mapper._equivalent_columns
+ or {},
+ chain_to=self._filter_aliases,
+ )
# if the onclause is a ClauseElement, adapt it with any
# adapters that are in place right now
@@ -2631,20 +2747,21 @@ class Query(object):
self._mapper_loads_polymorphically_with(
right_mapper,
ORMAdapter(
- right,
- equivalents=right_mapper._equivalent_columns
- )
+ right, equivalents=right_mapper._equivalent_columns
+ ),
)
# if joining on a MapperProperty path,
# track the path to prevent redundant joins
if not create_aliases and prop:
- self._update_joinpoint({
- '_joinpoint_entity': right,
- 'prev': ((left, right, prop.key), self._joinpoint)
- })
+ self._update_joinpoint(
+ {
+ "_joinpoint_entity": right,
+ "prev": ((left, right, prop.key), self._joinpoint),
+ }
+ )
else:
- self._joinpoint = {'_joinpoint_entity': right}
+ self._joinpoint = {"_joinpoint_entity": right}
return right, inspect(right), onclause
@@ -2821,27 +2938,30 @@ class Query(object):
if isinstance(item, slice):
start, stop, step = util.decode_slice(item)
- if isinstance(stop, int) and \
- isinstance(start, int) and \
- stop - start <= 0:
+ if (
+ isinstance(stop, int)
+ and isinstance(start, int)
+ and stop - start <= 0
+ ):
return []
# perhaps we should execute a count() here so that we
# can still use LIMIT/OFFSET ?
- elif (isinstance(start, int) and start < 0) \
- or (isinstance(stop, int) and stop < 0):
+ elif (isinstance(start, int) and start < 0) or (
+ isinstance(stop, int) and stop < 0
+ ):
return list(self)[item]
res = self.slice(start, stop)
if step is not None:
- return list(res)[None:None:item.step]
+ return list(res)[None : None : item.step]
else:
return list(res)
else:
if item == -1:
return list(self)[-1]
else:
- return list(self[item:item + 1])[0]
+ return list(self[item : item + 1])[0]
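The branches above give :class:`.Query` list-like indexing semantics. A sketch (``User`` hypothetical)::

    q = session.query(User).order_by(User.id)
    page = q[10:20]   # becomes LIMIT 10 OFFSET 10 via slice()
    first = q[0]      # LIMIT 1, then indexes the result list
    last = q[-1]      # negative index: evaluates the full result in Python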
@_generative(_no_statement_condition)
def slice(self, start, stop):
@@ -3014,12 +3134,13 @@ class Query(object):
"""
statement = expression._expression_literal_as_text(statement)
- if not isinstance(statement,
- (expression.TextClause,
- expression.SelectBase)):
+ if not isinstance(
+ statement, (expression.TextClause, expression.SelectBase)
+ ):
raise sa_exc.ArgumentError(
"from_statement accepts text(), select(), "
- "and union() objects only.")
+ "and union() objects only."
+ )
self._statement = statement
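A usage sketch for the accepted statement types, assuming ``from sqlalchemy import text`` and a hypothetical ``User`` mapping::

    stmt = text("SELECT * FROM users WHERE name = :name")
    session.query(User).from_statement(stmt).params(name="ed").all()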
@@ -3082,7 +3203,8 @@ class Query(object):
return None
else:
raise orm_exc.MultipleResultsFound(
- "Multiple rows were found for one_or_none()")
+ "Multiple rows were found for one_or_none()"
+ )
def one(self):
"""Return exactly one result or raise an exception.
@@ -3106,7 +3228,8 @@ class Query(object):
ret = self.one_or_none()
except orm_exc.MultipleResultsFound:
raise orm_exc.MultipleResultsFound(
- "Multiple rows were found for one()")
+ "Multiple rows were found for one()"
+ )
else:
if ret is None:
raise orm_exc.NoResultFound("No row was found for one()")
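The two methods differ only in the zero-row case — a sketch (``User`` hypothetical)::

    session.query(User).filter(User.id == 5).one()
    # zero rows -> NoResultFound; more than one -> MultipleResultsFound
    session.query(User).filter(User.id == 5).one_or_none()
    # zero rows -> None; more than one -> MultipleResultsFound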
@@ -3149,8 +3272,11 @@ class Query(object):
def __str__(self):
context = self._compile_context()
try:
- bind = self._get_bind_args(
- context, self.session.get_bind) if self.session else None
+ bind = (
+ self._get_bind_args(context, self.session.get_bind)
+ if self.session
+ else None
+ )
except sa_exc.UnboundExecutionError:
bind = None
return str(context.statement.compile(bind))
@@ -3163,24 +3289,22 @@ class Query(object):
def _execute_and_instances(self, querycontext):
conn = self._get_bind_args(
- querycontext,
- self._connection_from_session,
- close_with_result=True)
+ querycontext, self._connection_from_session, close_with_result=True
+ )
result = conn.execute(querycontext.statement, self._params)
return loading.instances(querycontext.query, result, querycontext)
def _execute_crud(self, stmt, mapper):
conn = self._connection_from_session(
- mapper=mapper, clause=stmt, close_with_result=True)
+ mapper=mapper, clause=stmt, close_with_result=True
+ )
return conn.execute(stmt, self._params)
def _get_bind_args(self, querycontext, fn, **kw):
return fn(
- mapper=self._bind_mapper(),
- clause=querycontext.statement,
- **kw
+ mapper=self._bind_mapper(), clause=querycontext.statement, **kw
)
@property
@@ -3225,21 +3349,23 @@ class Query(object):
return [
{
- 'name': ent._label_name,
- 'type': ent.type,
- 'aliased': getattr(insp_ent, 'is_aliased_class', False),
- 'expr': ent.expr,
- 'entity':
- getattr(insp_ent, "entity", None)
- if ent.entity_zero is not None
- and not insp_ent.is_clause_element
- else None
+ "name": ent._label_name,
+ "type": ent.type,
+ "aliased": getattr(insp_ent, "is_aliased_class", False),
+ "expr": ent.expr,
+ "entity": getattr(insp_ent, "entity", None)
+ if ent.entity_zero is not None
+ and not insp_ent.is_clause_element
+ else None,
}
for ent, insp_ent in [
(
_ent,
- (inspect(_ent.entity_zero)
- if _ent.entity_zero is not None else None)
+ (
+ inspect(_ent.entity_zero)
+ if _ent.entity_zero is not None
+ else None
+ ),
)
for _ent in self._entities
]
@@ -3290,21 +3416,23 @@ class Query(object):
@property
def _select_args(self):
return {
- 'limit': self._limit,
- 'offset': self._offset,
- 'distinct': self._distinct,
- 'prefixes': self._prefixes,
- 'suffixes': self._suffixes,
- 'group_by': self._group_by or None,
- 'having': self._having
+ "limit": self._limit,
+ "offset": self._offset,
+ "distinct": self._distinct,
+ "prefixes": self._prefixes,
+ "suffixes": self._suffixes,
+ "group_by": self._group_by or None,
+ "having": self._having,
}
@property
def _should_nest_selectable(self):
kwargs = self._select_args
- return (kwargs.get('limit') is not None or
- kwargs.get('offset') is not None or
- kwargs.get('distinct', False))
+ return (
+ kwargs.get("limit") is not None
+ or kwargs.get("offset") is not None
+ or kwargs.get("distinct", False)
+ )
def exists(self):
"""A convenience method that turns a query into an EXISTS subquery
@@ -3343,9 +3471,12 @@ class Query(object):
# omitting the FROM clause from a query(X) (#2818);
# .with_only_columns() after we have a core select() so that
# we get just "SELECT 1" without any entities.
- return sql.exists(self.enable_eagerloads(False).add_columns('1').
- with_labels().
- statement.with_only_columns([1]))
+ return sql.exists(
+ self.enable_eagerloads(False)
+ .add_columns("1")
+ .with_labels()
+ .statement.with_only_columns([1])
+ )
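A usage sketch of the EXISTS construction above (``User`` hypothetical)::

    q = session.query(User).filter(User.name == "fred")
    session.query(q.exists()).scalar()
    # roughly: SELECT EXISTS (SELECT 1 FROM users WHERE users.name = :name_1)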
def count(self):
r"""Return a count of rows this Query would return.
@@ -3384,10 +3515,10 @@ class Query(object):
session.query(func.count(distinct(User.name)))
"""
- col = sql.func.count(sql.literal_column('*'))
+ col = sql.func.count(sql.literal_column("*"))
return self.from_self(col).scalar()
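As the ``from_self()`` call above implies, the count is taken over the query wrapped as a subquery. A sketch (``User`` hypothetical)::

    session.query(User).filter(User.name.like("%ed%")).count()
    # roughly: SELECT count(*) FROM (SELECT ... FROM users WHERE ...) AS anon_1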
- def delete(self, synchronize_session='evaluate'):
+ def delete(self, synchronize_session="evaluate"):
r"""Perform a bulk delete query.
Deletes rows matched by this query from the database.
@@ -3506,12 +3637,11 @@ class Query(object):
"""
- delete_op = persistence.BulkDelete.factory(
- self, synchronize_session)
+ delete_op = persistence.BulkDelete.factory(self, synchronize_session)
delete_op.exec_()
return delete_op.rowcount
- def update(self, values, synchronize_session='evaluate', update_args=None):
+ def update(self, values, synchronize_session="evaluate", update_args=None):
r"""Perform a bulk update query.
Updates rows matched by this query in the database.
@@ -3640,7 +3770,8 @@ class Query(object):
update_args = update_args or {}
update_op = persistence.BulkUpdate.factory(
- self, synchronize_session, values, update_args)
+ self, synchronize_session, values, update_args
+ )
update_op.exec_()
return update_op.rowcount
@@ -3682,11 +3813,12 @@ class Query(object):
raise sa_exc.InvalidRequestError(
"No column-based properties specified for "
"refresh operation. Use session.expire() "
- "to reload collections and related items.")
+ "to reload collections and related items."
+ )
else:
raise sa_exc.InvalidRequestError(
- "Query contains no columns with which to "
- "SELECT from.")
+ "Query contains no columns with which to " "SELECT from."
+ )
if context.multi_row_eager_loaders and self._should_nest_selectable:
context.statement = self._compound_eager_statement(context)
@@ -3701,11 +3833,9 @@ class Query(object):
# then append eager joins onto that
if context.order_by:
- order_by_col_expr = \
- sql_util.expand_column_list_from_order_by(
- context.primary_columns,
- context.order_by
- )
+ order_by_col_expr = sql_util.expand_column_list_from_order_by(
+ context.primary_columns, context.order_by
+ )
else:
context.order_by = None
order_by_col_expr = []
@@ -3738,15 +3868,17 @@ class Query(object):
context.adapter = sql_util.ColumnAdapter(inner, equivs)
statement = sql.select(
- [inner] + context.secondary_columns,
- use_labels=context.labels)
+ [inner] + context.secondary_columns, use_labels=context.labels
+ )
# Oracle however does not allow FOR UPDATE on the subquery,
# and the Oracle dialect ignores it, plus for PostgreSQL, MySQL
# we expect that all elements of the row are locked, so also put it
# on the outside (except in the case of PG when OF is used)
- if context._for_update_arg is not None and \
- context._for_update_arg.of is None:
+ if (
+ context._for_update_arg is not None
+ and context._for_update_arg.of is None
+ ):
statement._for_update_arg = context._for_update_arg
from_clause = inner
@@ -3755,16 +3887,14 @@ class Query(object):
# giving us a marker as to where the "splice point" of
# the join should be
from_clause = sql_util.splice_joins(
- from_clause,
- eager_join, eager_join.stop_on)
+ from_clause, eager_join, eager_join.stop_on
+ )
statement.append_from(from_clause)
if context.order_by:
statement.append_order_by(
- *context.adapter.copy_and_process(
- context.order_by
- )
+ *context.adapter.copy_and_process(context.order_by)
)
statement.append_order_by(*context.eager_order_by)
@@ -3775,16 +3905,13 @@ class Query(object):
context.order_by = None
if self._distinct is True and context.order_by:
- context.primary_columns += \
- sql_util.expand_column_list_from_order_by(
- context.primary_columns,
- context.order_by
- )
+            context.primary_columns += (
+                sql_util.expand_column_list_from_order_by(
+                    context.primary_columns, context.order_by
+                )
+            )
context.froms += tuple(context.eager_joins.values())
statement = sql.select(
- context.primary_columns +
- context.secondary_columns,
+ context.primary_columns + context.secondary_columns,
context.whereclause,
from_obj=context.froms,
use_labels=context.labels,
@@ -3815,8 +3942,10 @@ class Query(object):
"""
search = set(self._mapper_adapter_map.values())
- if self._select_from_entity and \
- self._select_from_entity not in self._mapper_adapter_map:
+ if (
+ self._select_from_entity
+ and self._select_from_entity not in self._mapper_adapter_map
+ ):
insp = inspect(self._select_from_entity)
if insp.is_aliased_class:
adapter = insp._adapter
@@ -3833,8 +3962,8 @@ class Query(object):
single_crit = adapter.traverse(single_crit)
single_crit = self._adapt_clause(single_crit, False, False)
context.whereclause = sql.and_(
- sql.True_._ifnone(context.whereclause),
- single_crit)
+ sql.True_._ifnone(context.whereclause), single_crit
+ )
from ..sql.selectable import ForUpdateArg
@@ -3856,7 +3985,8 @@ class LockmodeArg(ForUpdateArg):
read = False
else:
raise sa_exc.ArgumentError(
- "Unknown with_lockmode argument: %r" % mode)
+ "Unknown with_lockmode argument: %r" % mode
+ )
return LockmodeArg(read=read, nowait=nowait)
@@ -3867,8 +3997,9 @@ class _QueryEntity(object):
def __new__(cls, *args, **kwargs):
if cls is _QueryEntity:
entity = args[1]
- if not isinstance(entity, util.string_types) and \
- _is_mapped_class(entity):
+ if not isinstance(entity, util.string_types) and _is_mapped_class(
+ entity
+ ):
cls = _MapperEntity
elif isinstance(entity, Bundle):
cls = _BundleEntity
@@ -3903,8 +4034,7 @@ class _MapperEntity(_QueryEntity):
self.selectable = ext_info.selectable
self.is_aliased_class = ext_info.is_aliased_class
self._with_polymorphic = ext_info.with_polymorphic_mappers
- self._polymorphic_discriminator = \
- ext_info.polymorphic_on
+ self._polymorphic_discriminator = ext_info.polymorphic_on
self.entity_zero = ext_info
if ext_info.is_aliased_class:
self._label_name = self.entity_zero.name
@@ -3912,8 +4042,9 @@ class _MapperEntity(_QueryEntity):
self._label_name = self.mapper.class_.__name__
self.path = self.entity_zero._path_registry
- def set_with_polymorphic(self, query, cls_or_mappers,
- selectable, polymorphic_on):
+ def set_with_polymorphic(
+ self, query, cls_or_mappers, selectable, polymorphic_on
+ ):
"""Receive an update from a call to query.with_polymorphic().
        Note the newer style of using a free-standing with_polymorphic()
@@ -3924,8 +4055,7 @@ class _MapperEntity(_QueryEntity):
if self.is_aliased_class:
# TODO: invalidrequest ?
raise NotImplementedError(
- "Can't use with_polymorphic() against "
- "an Aliased object"
+ "Can't use with_polymorphic() against " "an Aliased object"
)
if cls_or_mappers is None:
@@ -3933,14 +4063,16 @@ class _MapperEntity(_QueryEntity):
return
mappers, from_obj = self.mapper._with_polymorphic_args(
- cls_or_mappers, selectable)
+ cls_or_mappers, selectable
+ )
self._with_polymorphic = mappers
self._polymorphic_discriminator = polymorphic_on
self.selectable = from_obj
query._mapper_loads_polymorphically_with(
- self.mapper, sql_util.ColumnAdapter(
- from_obj, self.mapper._equivalent_columns))
+ self.mapper,
+ sql_util.ColumnAdapter(from_obj, self.mapper._equivalent_columns),
+ )
@property
def type(self):
@@ -3989,8 +4121,8 @@ class _MapperEntity(_QueryEntity):
# require row aliasing unconditionally.
if not adapter and self.mapper._requires_row_aliasing:
adapter = sql_util.ColumnAdapter(
- self.selectable,
- self.mapper._equivalent_columns)
+ self.selectable, self.mapper._equivalent_columns
+ )
if query._primary_entity is self:
only_load_props = query._only_load_props
@@ -4006,7 +4138,7 @@ class _MapperEntity(_QueryEntity):
adapter,
only_load_props=only_load_props,
refresh_state=refresh_state,
- polymorphic_discriminator=self._polymorphic_discriminator
+ polymorphic_discriminator=self._polymorphic_discriminator,
)
return _instance, self._label_name
@@ -4023,17 +4155,19 @@ class _MapperEntity(_QueryEntity):
# apply adaptation to the mapper's order_by if needed.
if adapter:
context.order_by = adapter.adapt_list(
- util.to_list(
- context.order_by
- )
+ util.to_list(context.order_by)
)
loading._setup_entity_query(
- context, self.mapper, self,
- self.path, adapter, context.primary_columns,
+ context,
+ self.mapper,
+ self,
+ self.path,
+ adapter,
+ context.primary_columns,
with_polymorphic=self._with_polymorphic,
only_load_props=query._only_load_props,
- polymorphic_discriminator=self._polymorphic_discriminator
+ polymorphic_discriminator=self._polymorphic_discriminator,
)
def __str__(self):
@@ -4091,9 +4225,10 @@ class Bundle(InspectionAttr):
self.name = self._label = name
self.exprs = exprs
self.c = self.columns = ColumnCollection()
- self.columns.update((getattr(col, "key", col._label), col)
- for col in exprs)
- self.single_entity = kw.pop('single_entity', self.single_entity)
+ self.columns.update(
+ (getattr(col, "key", col._label), col) for col in exprs
+ )
+ self.single_entity = kw.pop("single_entity", self.single_entity)
columns = None
"""A namespace of SQL expressions referred to by this :class:`.Bundle`.
@@ -4152,10 +4287,11 @@ class Bundle(InspectionAttr):
:ref:`bundles` - includes an example of subclassing.
"""
- keyed_tuple = util.lightweight_named_tuple('result', labels)
+ keyed_tuple = util.lightweight_named_tuple("result", labels)
def proc(row):
return keyed_tuple([proc(row) for proc in procs])
+
return proc
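A usage sketch of :class:`.Bundle` as reformatted here (``Data`` is a hypothetical mapped class with ``d1``/``d2`` columns)::

    from sqlalchemy.orm import Bundle

    bn = Bundle("pair", Data.d1, Data.d2)
    for row in session.query(bn).filter(bn.c.d1 == "x"):
        print(row.pair.d1, row.pair.d2)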
@@ -4235,8 +4371,10 @@ class _BundleEntity(_QueryEntity):
def row_processor(self, query, context, result):
procs, labels = zip(
- *[ent.row_processor(query, context, result)
- for ent in self._entities]
+ *[
+ ent.row_processor(query, context, result)
+ for ent in self._entities
+ ]
)
proc = self.bundle.create_row_processor(query, procs, labels)
@@ -4259,11 +4397,10 @@ class _ColumnEntity(_QueryEntity):
search_entities = False
check_column = True
_entity = None
- elif isinstance(column, (
- attributes.QueryableAttribute,
- interfaces.PropComparator
- )):
- _entity = getattr(column, '_parententity', None)
+ elif isinstance(
+ column, (attributes.QueryableAttribute, interfaces.PropComparator)
+ ):
+ _entity = getattr(column, "_parententity", None)
if _entity is not None:
search_entities = False
self._label_name = column.key
@@ -4274,7 +4411,7 @@ class _ColumnEntity(_QueryEntity):
return
if not isinstance(column, sql.ColumnElement):
- if hasattr(column, '_select_iterable'):
+ if hasattr(column, "_select_iterable"):
# break out an object like Table into
# individual columns
for c in column._select_iterable:
@@ -4286,10 +4423,10 @@ class _ColumnEntity(_QueryEntity):
raise sa_exc.InvalidRequestError(
"SQL expression, column, or mapped entity "
- "expected - got '%r'" % (column, )
+ "expected - got '%r'" % (column,)
)
elif not check_column:
- self._label_name = getattr(column, 'key', None)
+ self._label_name = getattr(column, "key", None)
search_entities = True
self.type = type_ = column.type
@@ -4301,7 +4438,7 @@ class _ColumnEntity(_QueryEntity):
# if the expression's identity has been changed
# due to adaption.
- if not column._label and not getattr(column, 'is_literal', False):
+ if not column._label and not getattr(column, "is_literal", False):
column = column.label(self._label_name)
query._entities.append(self)
@@ -4328,23 +4465,29 @@ class _ColumnEntity(_QueryEntity):
self._from_entities = set(self.entities)
else:
all_elements = [
- elem for elem in sql_util.surface_column_elements(
- column, include_scalar_selects=False)
- if 'parententity' in elem._annotations
+ elem
+ for elem in sql_util.surface_column_elements(
+ column, include_scalar_selects=False
+ )
+ if "parententity" in elem._annotations
]
- self.entities = util.unique_list([
- elem._annotations['parententity']
- for elem in all_elements
- if 'parententity' in elem._annotations
- ])
-
- self._from_entities = set([
- elem._annotations['parententity']
- for elem in all_elements
- if 'parententity' in elem._annotations
- and actual_froms.intersection(elem._from_objects)
- ])
+ self.entities = util.unique_list(
+ [
+ elem._annotations["parententity"]
+ for elem in all_elements
+ if "parententity" in elem._annotations
+ ]
+ )
+
+ self._from_entities = set(
+ [
+ elem._annotations["parententity"]
+ for elem in all_elements
+ if "parententity" in elem._annotations
+ and actual_froms.intersection(elem._from_objects)
+ ]
+ )
if self.entities:
self.entity_zero = self.entities[0]
self.mapper = self.entity_zero.mapper
@@ -4373,7 +4516,7 @@ class _ColumnEntity(_QueryEntity):
c.entities = self.entities
def setup_entity(self, ext_info, aliased_adapter):
- if 'selectable' not in self.__dict__:
+ if "selectable" not in self.__dict__:
self.selectable = ext_info.selectable
if self.actual_froms.intersection(ext_info.selectable._from_objects):
@@ -4386,12 +4529,13 @@ class _ColumnEntity(_QueryEntity):
# TODO: polymorphic subclasses ?
return entity is self.entity_zero
else:
- return not _is_aliased_class(self.entity_zero) and \
- entity.common_parent(self.entity_zero)
+ return not _is_aliased_class(
+ self.entity_zero
+ ) and entity.common_parent(self.entity_zero)
def row_processor(self, query, context, result):
- if ('fetch_column', self) in context.attributes:
- column = context.attributes[('fetch_column', self)]
+ if ("fetch_column", self) in context.attributes:
+ column = context.attributes[("fetch_column", self)]
else:
column = query._adapt_clause(self.column, False, True)
@@ -4417,7 +4561,7 @@ class _ColumnEntity(_QueryEntity):
context.froms += tuple(self.froms)
context.primary_columns.append(column)
- context.attributes[('fetch_column', self)] = column
+ context.attributes[("fetch_column", self)] = column
def __str__(self):
return str(self.column)
@@ -4425,22 +4569,44 @@ class _ColumnEntity(_QueryEntity):
class QueryContext(object):
__slots__ = (
- 'multi_row_eager_loaders', 'adapter', 'froms', 'for_update',
- 'query', 'session', 'autoflush', 'populate_existing',
- 'invoke_all_eagers', 'version_check', 'refresh_state',
- 'primary_columns', 'secondary_columns', 'eager_order_by',
- 'eager_joins', 'create_eager_joins', 'propagate_options',
- 'attributes', 'statement', 'from_clause', 'whereclause',
- 'order_by', 'labels', '_for_update_arg', 'runid', 'partials',
- 'post_load_paths', 'identity_token'
+ "multi_row_eager_loaders",
+ "adapter",
+ "froms",
+ "for_update",
+ "query",
+ "session",
+ "autoflush",
+ "populate_existing",
+ "invoke_all_eagers",
+ "version_check",
+ "refresh_state",
+ "primary_columns",
+ "secondary_columns",
+ "eager_order_by",
+ "eager_joins",
+ "create_eager_joins",
+ "propagate_options",
+ "attributes",
+ "statement",
+ "from_clause",
+ "whereclause",
+ "order_by",
+ "labels",
+ "_for_update_arg",
+ "runid",
+ "partials",
+ "post_load_paths",
+ "identity_token",
)
def __init__(self, query):
if query._statement is not None:
- if isinstance(query._statement, expression.SelectBase) and \
- not query._statement._textual and \
- not query._statement.use_labels:
+ if (
+ isinstance(query._statement, expression.SelectBase)
+ and not query._statement._textual
+ and not query._statement.use_labels
+ ):
self.statement = query._statement.apply_labels()
else:
self.statement = query._statement
@@ -4466,8 +4632,9 @@ class QueryContext(object):
self.eager_order_by = []
self.eager_joins = {}
self.create_eager_joins = []
- self.propagate_options = set(o for o in query._with_options if
- o.propagate_to_loaders)
+ self.propagate_options = set(
+ o for o in query._with_options if o.propagate_to_loaders
+ )
self.attributes = query._attributes.copy()
if self.refresh_state is not None:
self.identity_token = query._refresh_identity_token
@@ -4476,7 +4643,6 @@ class QueryContext(object):
class AliasOption(interfaces.MapperOption):
-
def __init__(self, alias):
r"""Return a :class:`.MapperOption` that will indicate to the :class:`.Query`
that the main table has been aliased.
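# An illustrative sketch of the option this class implements: the
# contains_alias() function wraps AliasOption, telling the Query that the
# main table in a hand-built statement is aliased.  ``users`` is an
# assumed Table mapped by ``User``.
from sqlalchemy.orm import contains_alias

ulist = users.select().alias("ulist")
result = (
    session.query(User)
    .options(contains_alias(ulist))
    .from_statement(ulist.select())
    .all()
)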
diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py
index e7896c423..e89d1542f 100644
--- a/lib/sqlalchemy/orm/relationships.py
+++ b/lib/sqlalchemy/orm/relationships.py
@@ -22,14 +22,23 @@ from . import dependency
from . import attributes
from ..sql.util import (
ClauseAdapter,
- join_condition, _shallow_annotate, visit_binary_product,
- _deep_deannotate, selectables_overlap, adapt_criterion_to_null
+ join_condition,
+ _shallow_annotate,
+ visit_binary_product,
+ _deep_deannotate,
+ selectables_overlap,
+ adapt_criterion_to_null,
)
from .base import state_str
from ..sql import operators, expression, visitors
-from .interfaces import (MANYTOMANY, MANYTOONE, ONETOMANY,
- StrategizedProperty, PropComparator)
+from .interfaces import (
+ MANYTOMANY,
+ MANYTOONE,
+ ONETOMANY,
+ StrategizedProperty,
+ PropComparator,
+)
from ..inspection import inspect
from . import mapper as mapperlib
import collections
@@ -51,8 +60,9 @@ def remote(expr):
:func:`.foreign`
"""
- return _annotate_columns(expression._clause_element_as_expr(expr),
- {"remote": True})
+ return _annotate_columns(
+ expression._clause_element_as_expr(expr), {"remote": True}
+ )
def foreign(expr):
@@ -72,8 +82,9 @@ def foreign(expr):
"""
- return _annotate_columns(expression._clause_element_as_expr(expr),
- {"foreign": True})
+ return _annotate_columns(
+ expression._clause_element_as_expr(expr), {"foreign": True}
+ )
@log.class_logger
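# A minimal sketch of the two annotation helpers reformatted above, on a
# self-referential many-to-one (all names illustrative): remote() marks
# the parent row's column, foreign() marks the referencing column.
from sqlalchemy import Column, ForeignKey, Integer
from sqlalchemy.orm import foreign, relationship, remote

class Node(Base):  # Base assumed: a declarative_base()
    __tablename__ = "node"
    id = Column(Integer, primary_key=True)
    parent_id = Column(Integer, ForeignKey("node.id"))
    parent = relationship(
        "Node",
        primaryjoin=remote(id) == foreign(parent_id),
    )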
@@ -90,36 +101,46 @@ class RelationshipProperty(StrategizedProperty):
"""
- strategy_wildcard_key = 'relationship'
+ strategy_wildcard_key = "relationship"
_dependency_processor = None
- def __init__(self, argument,
- secondary=None, primaryjoin=None,
- secondaryjoin=None,
- foreign_keys=None,
- uselist=None,
- order_by=False,
- backref=None,
- back_populates=None,
- post_update=False,
- cascade=False, extension=None,
- viewonly=False, lazy="select",
- collection_class=None, passive_deletes=False,
- passive_updates=True, remote_side=None,
- enable_typechecks=True, join_depth=None,
- comparator_factory=None,
- single_parent=False, innerjoin=False,
- distinct_target_key=None,
- doc=None,
- active_history=False,
- cascade_backrefs=True,
- load_on_pending=False,
- bake_queries=True,
- _local_remote_pairs=None,
- query_class=None,
- info=None,
- omit_join=None):
+ def __init__(
+ self,
+ argument,
+ secondary=None,
+ primaryjoin=None,
+ secondaryjoin=None,
+ foreign_keys=None,
+ uselist=None,
+ order_by=False,
+ backref=None,
+ back_populates=None,
+ post_update=False,
+ cascade=False,
+ extension=None,
+ viewonly=False,
+ lazy="select",
+ collection_class=None,
+ passive_deletes=False,
+ passive_updates=True,
+ remote_side=None,
+ enable_typechecks=True,
+ join_depth=None,
+ comparator_factory=None,
+ single_parent=False,
+ innerjoin=False,
+ distinct_target_key=None,
+ doc=None,
+ active_history=False,
+ cascade_backrefs=True,
+ load_on_pending=False,
+ bake_queries=True,
+ _local_remote_pairs=None,
+ query_class=None,
+ info=None,
+ omit_join=None,
+ ):
"""Provide a relationship between two mapped classes.
This corresponds to a parent-child or associative table relationship.
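# A hedged sketch of the constructor above in its most common shape: a
# one-to-many paired with the reverse many-to-one via back_populates
# (all names illustrative).
from sqlalchemy import Column, ForeignKey, Integer
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship

Base = declarative_base()

class Parent(Base):
    __tablename__ = "parent"
    id = Column(Integer, primary_key=True)
    # lazy="select" and cascade="save-update, merge" are the defaults
    # computed in the body above; spelled out here for illustration.
    children = relationship(
        "Child",
        back_populates="parent",
        lazy="select",
        cascade="save-update, merge",
    )

class Child(Base):
    __tablename__ = "child"
    id = Column(Integer, primary_key=True)
    parent_id = Column(Integer, ForeignKey("parent.id"))
    parent = relationship("Parent", back_populates="children")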
@@ -858,20 +879,22 @@ class RelationshipProperty(StrategizedProperty):
self.extension = extension
self.bake_queries = bake_queries
self.load_on_pending = load_on_pending
- self.comparator_factory = comparator_factory or \
- RelationshipProperty.Comparator
+ self.comparator_factory = (
+ comparator_factory or RelationshipProperty.Comparator
+ )
self.comparator = self.comparator_factory(self, None)
util.set_creation_order(self)
if info is not None:
self.info = info
- self.strategy_key = (("lazy", self.lazy), )
+ self.strategy_key = (("lazy", self.lazy),)
self._reverse_property = set()
- self.cascade = cascade if cascade is not False \
- else "save-update, merge"
+ self.cascade = (
+ cascade if cascade is not False else "save-update, merge"
+ )
self.order_by = order_by
@@ -881,7 +904,8 @@ class RelationshipProperty(StrategizedProperty):
if backref:
raise sa_exc.ArgumentError(
"backref and back_populates keyword arguments "
- "are mutually exclusive")
+ "are mutually exclusive"
+ )
self.backref = None
else:
self.backref = backref
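# The mutual exclusion enforced just above, sketched: backref creates the
# reverse relationship implicitly, so it cannot be combined with an
# explicit back_populates.  This one-liner is equivalent to the
# Parent/Child pair in the earlier sketch:
children = relationship("Child", backref="parent")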
@@ -919,7 +943,8 @@ class RelationshipProperty(StrategizedProperty):
_of_type = None
def __init__(
- self, prop, parentmapper, adapt_to_entity=None, of_type=None):
+ self, prop, parentmapper, adapt_to_entity=None, of_type=None
+ ):
"""Construction of :class:`.RelationshipProperty.Comparator`
is internal to the ORM's attribute mechanics.
@@ -931,9 +956,12 @@ class RelationshipProperty(StrategizedProperty):
self._of_type = of_type
def adapt_to_entity(self, adapt_to_entity):
- return self.__class__(self.property, self._parententity,
- adapt_to_entity=adapt_to_entity,
- of_type=self._of_type)
+ return self.__class__(
+ self.property,
+ self._parententity,
+ adapt_to_entity=adapt_to_entity,
+ of_type=self._of_type,
+ )
@util.memoized_property
def mapper(self):
@@ -963,11 +991,11 @@ class RelationshipProperty(StrategizedProperty):
else:
of_type = None
- pj, sj, source, dest, \
- secondary, target_adapter = self.property._create_joins(
- source_selectable=adapt_from,
- source_polymorphic=True,
- of_type=of_type)
+ pj, sj, source, dest, secondary, target_adapter = self.property._create_joins(
+ source_selectable=adapt_from,
+ source_polymorphic=True,
+ of_type=of_type,
+ )
if sj is not None:
return pj & sj
else:
@@ -983,17 +1011,20 @@ class RelationshipProperty(StrategizedProperty):
self.property,
self._parententity,
adapt_to_entity=self._adapt_to_entity,
- of_type=cls)
+ of_type=cls,
+ )
def in_(self, other):
"""Produce an IN clause - this is not implemented
for :func:`~.orm.relationship`-based attributes at this time.
"""
- raise NotImplementedError('in_() not yet supported for '
- 'relationships. For a simple '
- 'many-to-one, use in_() against '
- 'the set of foreign key values.')
+ raise NotImplementedError(
+ "in_() not yet supported for "
+ "relationships. For a simple "
+ "many-to-one, use in_() against "
+ "the set of foreign key values."
+ )
__hash__ = None
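# The workaround named in the error above, sketched for a simple
# many-to-one (Child.parent_id is the assumed FK column behind
# Child.parent): filter on the foreign key values directly.
session.query(Child).filter(Child.parent_id.in_([1, 2, 3]))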
@@ -1038,24 +1069,32 @@ class RelationshipProperty(StrategizedProperty):
if self.property.direction in [ONETOMANY, MANYTOMANY]:
return ~self._criterion_exists()
else:
- return _orm_annotate(self.property._optimized_compare(
- None, adapt_source=self.adapter))
+ return _orm_annotate(
+ self.property._optimized_compare(
+ None, adapt_source=self.adapter
+ )
+ )
elif self.property.uselist:
raise sa_exc.InvalidRequestError(
"Can't compare a collection to an object or collection; "
- "use contains() to test for membership.")
+ "use contains() to test for membership."
+ )
else:
return _orm_annotate(
self.property._optimized_compare(
- other, adapt_source=self.adapter))
+ other, adapt_source=self.adapter
+ )
+ )
def _criterion_exists(self, criterion=None, **kwargs):
- if getattr(self, '_of_type', None):
+ if getattr(self, "_of_type", None):
info = inspect(self._of_type)
- target_mapper, to_selectable, is_aliased_class = \
- info.mapper, info.selectable, info.is_aliased_class
- if self.property._is_self_referential and not \
- is_aliased_class:
+ target_mapper, to_selectable, is_aliased_class = (
+ info.mapper,
+ info.selectable,
+ info.is_aliased_class,
+ )
+ if self.property._is_self_referential and not is_aliased_class:
to_selectable = to_selectable.alias()
single_crit = target_mapper._single_table_criterion
@@ -1073,11 +1112,11 @@ class RelationshipProperty(StrategizedProperty):
else:
source_selectable = None
- pj, sj, source, dest, secondary, target_adapter = \
- self.property._create_joins(
- dest_polymorphic=True,
- dest_selectable=to_selectable,
- source_selectable=source_selectable)
+ pj, sj, source, dest, secondary, target_adapter = self.property._create_joins(
+ dest_polymorphic=True,
+ dest_selectable=to_selectable,
+ source_selectable=source_selectable,
+ )
for k in kwargs:
crit = getattr(self.property.mapper.class_, k) == kwargs[k]
@@ -1094,8 +1133,11 @@ class RelationshipProperty(StrategizedProperty):
else:
j = _orm_annotate(pj, exclude=self.property.remote_side)
- if criterion is not None and target_adapter and not \
- is_aliased_class:
+ if (
+ criterion is not None
+ and target_adapter
+ and not is_aliased_class
+ ):
# limit this adapter to annotated only?
criterion = target_adapter.traverse(criterion)
@@ -1106,16 +1148,19 @@ class RelationshipProperty(StrategizedProperty):
# to anything in the enclosing query.
if criterion is not None:
criterion = criterion._annotate(
- {'no_replacement_traverse': True})
+ {"no_replacement_traverse": True}
+ )
crit = j & sql.True_._ifnone(criterion)
if secondary is not None:
- ex = sql.exists([1], crit, from_obj=[dest, secondary]).\
- correlate_except(dest, secondary)
+ ex = sql.exists(
+ [1], crit, from_obj=[dest, secondary]
+ ).correlate_except(dest, secondary)
else:
- ex = sql.exists([1], crit, from_obj=dest).\
- correlate_except(dest)
+ ex = sql.exists([1], crit, from_obj=dest).correlate_except(
+ dest
+ )
return ex
def any(self, criterion=None, **kwargs):
@@ -1197,8 +1242,8 @@ class RelationshipProperty(StrategizedProperty):
"""
if self.property.uselist:
raise sa_exc.InvalidRequestError(
- "'has()' not implemented for collections. "
- "Use any().")
+ "'has()' not implemented for collections. " "Use any()."
+ )
return self._criterion_exists(criterion, **kwargs)
def contains(self, other, **kwargs):
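# The split enforced above, sketched with the illustrative Parent/Child
# mapping from earlier: any() for collections, has() for scalar
# many-to-one references (a ``name`` column is assumed on each).
session.query(Parent).filter(Parent.children.any(Child.name == "kim"))
session.query(Child).filter(Child.parent.has(Parent.name == "alex"))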
@@ -1260,13 +1305,16 @@ class RelationshipProperty(StrategizedProperty):
if not self.property.uselist:
raise sa_exc.InvalidRequestError(
"'contains' not implemented for scalar "
- "attributes. Use ==")
+ "attributes. Use =="
+ )
clause = self.property._optimized_compare(
- other, adapt_source=self.adapter)
+ other, adapt_source=self.adapter
+ )
if self.property.secondaryjoin is not None:
- clause.negation_clause = \
- self.__negated_contains_or_equals(other)
+ clause.negation_clause = self.__negated_contains_or_equals(
+ other
+ )
return clause
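# contains() as distinguished above, sketched: a membership test against
# a collection, versus plain == for the scalar side (objects assumed).
session.query(Parent).filter(Parent.children.contains(some_child))
session.query(Child).filter(Child.parent == some_parent)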
@@ -1277,10 +1325,11 @@ class RelationshipProperty(StrategizedProperty):
def state_bindparam(x, state, col):
dict_ = state.dict
return sql.bindparam(
- x, unique=True,
+ x,
+ unique=True,
callable_=self.property._get_attr_w_warn_on_none(
self.property.mapper, state, dict_, col
- )
+ ),
)
def adapt(col):
@@ -1290,19 +1339,26 @@ class RelationshipProperty(StrategizedProperty):
return col
if self.property._use_get:
- return sql.and_(*[
- sql.or_(
- adapt(x) != state_bindparam(adapt(x), state, y),
- adapt(x) == None)
- for (x, y) in self.property.local_remote_pairs])
-
- criterion = sql.and_(*[
- x == y for (x, y) in
- zip(
- self.property.mapper.primary_key,
- self.property.mapper.primary_key_from_instance(other)
- )
- ])
+ return sql.and_(
+ *[
+ sql.or_(
+ adapt(x)
+ != state_bindparam(adapt(x), state, y),
+ adapt(x) == None,
+ )
+ for (x, y) in self.property.local_remote_pairs
+ ]
+ )
+
+ criterion = sql.and_(
+ *[
+ x == y
+ for (x, y) in zip(
+ self.property.mapper.primary_key,
+ self.property.mapper.primary_key_from_instance(other),
+ )
+ ]
+ )
return ~self._criterion_exists(criterion)
@@ -1347,8 +1403,11 @@ class RelationshipProperty(StrategizedProperty):
"""
if isinstance(other, (util.NoneType, expression.Null)):
if self.property.direction == MANYTOONE:
- return _orm_annotate(~self.property._optimized_compare(
- None, adapt_source=self.adapter))
+ return _orm_annotate(
+ ~self.property._optimized_compare(
+ None, adapt_source=self.adapter
+ )
+ )
else:
return self._criterion_exists()
@@ -1356,7 +1415,8 @@ class RelationshipProperty(StrategizedProperty):
raise sa_exc.InvalidRequestError(
"Can't compare a collection"
" to an object or collection; use "
- "contains() to test for membership.")
+ "contains() to test for membership."
+ )
else:
return _orm_annotate(self.__negated_contains_or_equals(other))
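# The branches above in query form (illustrative mapping): a many-to-one
# compared to None renders an IS (NOT) NULL criterion against the foreign
# key, while a collection compared to None renders (NOT) EXISTS.
session.query(Child).filter(Child.parent == None)  # noqa: E711
session.query(Parent).filter(Parent.children != None)  # noqa: E711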
@@ -1374,12 +1434,19 @@ class RelationshipProperty(StrategizedProperty):
if insp.is_aliased_class:
adapt_source = insp._adapter.adapt_clause
return self._optimized_compare(
- instance, value_is_parent=True, adapt_source=adapt_source,
- alias_secondary=alias_secondary)
+ instance,
+ value_is_parent=True,
+ adapt_source=adapt_source,
+ alias_secondary=alias_secondary,
+ )
- def _optimized_compare(self, state, value_is_parent=False,
- adapt_source=None,
- alias_secondary=True):
+ def _optimized_compare(
+ self,
+ state,
+ value_is_parent=False,
+ adapt_source=None,
+ alias_secondary=True,
+ ):
if state is not None:
state = attributes.instance_state(state)
@@ -1387,17 +1454,19 @@ class RelationshipProperty(StrategizedProperty):
if state is None:
return self._lazy_none_clause(
- reverse_direction,
- adapt_source=adapt_source)
+ reverse_direction, adapt_source=adapt_source
+ )
if not reverse_direction:
- criterion, bind_to_col = \
- self._lazy_strategy._lazywhere, \
- self._lazy_strategy._bind_to_col
+ criterion, bind_to_col = (
+ self._lazy_strategy._lazywhere,
+ self._lazy_strategy._bind_to_col,
+ )
else:
- criterion, bind_to_col = \
- self._lazy_strategy._rev_lazywhere, \
- self._lazy_strategy._rev_bind_to_col
+ criterion, bind_to_col = (
+ self._lazy_strategy._rev_lazywhere,
+ self._lazy_strategy._rev_bind_to_col,
+ )
if reverse_direction:
mapper = self.mapper
@@ -1409,16 +1478,20 @@ class RelationshipProperty(StrategizedProperty):
def visit_bindparam(bindparam):
if bindparam._identifying_key in bind_to_col:
bindparam.callable = self._get_attr_w_warn_on_none(
- mapper, state, dict_,
- bind_to_col[bindparam._identifying_key])
+ mapper,
+ state,
+ dict_,
+ bind_to_col[bindparam._identifying_key],
+ )
if self.secondary is not None and alias_secondary:
- criterion = ClauseAdapter(
- self.secondary.alias()).\
- traverse(criterion)
+ criterion = ClauseAdapter(self.secondary.alias()).traverse(
+ criterion
+ )
criterion = visitors.cloned_traverse(
- criterion, {}, {'bindparam': visit_bindparam})
+ criterion, {}, {"bindparam": visit_bindparam}
+ )
if adapt_source:
criterion = adapt_source(criterion)
@@ -1483,25 +1556,27 @@ class RelationshipProperty(StrategizedProperty):
# only if we can't get a value now due to detachment do we return
# the last known value
current_value = mapper._get_state_attr_by_column(
- state, dict_, column,
+ state,
+ dict_,
+ column,
passive=attributes.PASSIVE_RETURN_NEVER_SET
if state.persistent
- else attributes.PASSIVE_NO_FETCH ^ attributes.INIT_OK)
+ else attributes.PASSIVE_NO_FETCH ^ attributes.INIT_OK,
+ )
if current_value is attributes.NEVER_SET:
if not existing_is_available:
raise sa_exc.InvalidRequestError(
"Can't resolve value for column %s on object "
- "%s; no value has been set for this column" % (
- column, state_str(state))
+ "%s; no value has been set for this column"
+ % (column, state_str(state))
)
elif current_value is attributes.PASSIVE_NO_RESULT:
if not existing_is_available:
raise sa_exc.InvalidRequestError(
"Can't resolve value for column %s on object "
"%s; the object is detached and the value was "
- "expired" % (
- column, state_str(state))
+ "expired" % (column, state_str(state))
)
else:
to_return = current_value
@@ -1510,19 +1585,23 @@ class RelationshipProperty(StrategizedProperty):
"Got None for value of column %s; this is unsupported "
"for a relationship comparison and will not "
"currently produce an IS comparison "
- "(but may in a future release)" % column)
+ "(but may in a future release)" % column
+ )
return to_return
+
return _go
def _lazy_none_clause(self, reverse_direction=False, adapt_source=None):
if not reverse_direction:
- criterion, bind_to_col = \
- self._lazy_strategy._lazywhere, \
- self._lazy_strategy._bind_to_col
+ criterion, bind_to_col = (
+ self._lazy_strategy._lazywhere,
+ self._lazy_strategy._bind_to_col,
+ )
else:
- criterion, bind_to_col = \
- self._lazy_strategy._rev_lazywhere, \
- self._lazy_strategy._rev_bind_to_col
+ criterion, bind_to_col = (
+ self._lazy_strategy._rev_lazywhere,
+ self._lazy_strategy._rev_bind_to_col,
+ )
criterion = adapt_criterion_to_null(criterion, bind_to_col)
@@ -1533,13 +1612,17 @@ class RelationshipProperty(StrategizedProperty):
def __str__(self):
return str(self.parent.class_.__name__) + "." + self.key
- def merge(self,
- session,
- source_state,
- source_dict,
- dest_state,
- dest_dict,
- load, _recursive, _resolve_conflict_map):
+ def merge(
+ self,
+ session,
+ source_state,
+ source_dict,
+ dest_state,
+ dest_dict,
+ load,
+ _recursive,
+ _resolve_conflict_map,
+ ):
if load:
for r in self._reverse_property:
@@ -1553,9 +1636,10 @@ class RelationshipProperty(StrategizedProperty):
return
if self.uselist:
- instances = source_state.get_impl(self.key).\
- get(source_state, source_dict)
- if hasattr(instances, '_sa_adapter'):
+ instances = source_state.get_impl(self.key).get(
+ source_state, source_dict
+ )
+ if hasattr(instances, "_sa_adapter"):
# convert collections to adapters to get a true iterator
instances = instances._sa_adapter
@@ -1573,21 +1657,25 @@ class RelationshipProperty(StrategizedProperty):
current_dict = attributes.instance_dict(current)
_recursive[(current_state, self)] = True
obj = session._merge(
- current_state, current_dict,
- load=load, _recursive=_recursive,
- _resolve_conflict_map=_resolve_conflict_map)
+ current_state,
+ current_dict,
+ load=load,
+ _recursive=_recursive,
+ _resolve_conflict_map=_resolve_conflict_map,
+ )
if obj is not None:
dest_list.append(obj)
if not load:
- coll = attributes.init_state_collection(dest_state,
- dest_dict, self.key)
+ coll = attributes.init_state_collection(
+ dest_state, dest_dict, self.key
+ )
for c in dest_list:
coll.append_without_event(c)
else:
dest_state.get_impl(self.key).set(
- dest_state, dest_dict, dest_list,
- _adapt=False)
+ dest_state, dest_dict, dest_list, _adapt=False
+ )
else:
current = source_dict[self.key]
if current is not None:
@@ -1595,20 +1683,25 @@ class RelationshipProperty(StrategizedProperty):
current_dict = attributes.instance_dict(current)
_recursive[(current_state, self)] = True
obj = session._merge(
- current_state, current_dict,
- load=load, _recursive=_recursive,
- _resolve_conflict_map=_resolve_conflict_map)
+ current_state,
+ current_dict,
+ load=load,
+ _recursive=_recursive,
+ _resolve_conflict_map=_resolve_conflict_map,
+ )
else:
obj = None
if not load:
dest_dict[self.key] = obj
else:
- dest_state.get_impl(self.key).set(dest_state,
- dest_dict, obj, None)
+ dest_state.get_impl(self.key).set(
+ dest_state, dest_dict, obj, None
+ )
- def _value_as_iterable(self, state, dict_, key,
- passive=attributes.PASSIVE_OFF):
+ def _value_as_iterable(
+ self, state, dict_, key, passive=attributes.PASSIVE_OFF
+ ):
"""Return a list of tuples (state, obj) for the given
key.
@@ -1619,34 +1712,36 @@ class RelationshipProperty(StrategizedProperty):
x = impl.get(state, dict_, passive=passive)
if x is attributes.PASSIVE_NO_RESULT or x is None:
return []
- elif hasattr(impl, 'get_collection'):
+ elif hasattr(impl, "get_collection"):
return [
- (attributes.instance_state(o), o) for o in
- impl.get_collection(state, dict_, x, passive=passive)
+ (attributes.instance_state(o), o)
+ for o in impl.get_collection(state, dict_, x, passive=passive)
]
else:
return [(attributes.instance_state(x), x)]
- def cascade_iterator(self, type_, state, dict_,
- visited_states, halt_on=None):
+ def cascade_iterator(
+ self, type_, state, dict_, visited_states, halt_on=None
+ ):
# assert type_ in self._cascade
# only actively lazy load on the 'delete' cascade
- if type_ != 'delete' or self.passive_deletes:
+ if type_ != "delete" or self.passive_deletes:
passive = attributes.PASSIVE_NO_INITIALIZE
else:
passive = attributes.PASSIVE_OFF
- if type_ == 'save-update':
- tuples = state.manager[self.key].impl.\
- get_all_pending(state, dict_)
+ if type_ == "save-update":
+ tuples = state.manager[self.key].impl.get_all_pending(state, dict_)
else:
- tuples = self._value_as_iterable(state, dict_, self.key,
- passive=passive)
+ tuples = self._value_as_iterable(
+ state, dict_, self.key, passive=passive
+ )
- skip_pending = type_ == 'refresh-expire' and 'delete-orphan' \
- not in self._cascade
+ skip_pending = (
+ type_ == "refresh-expire" and "delete-orphan" not in self._cascade
+ )
for instance_state, c in tuples:
if instance_state in visited_states:
@@ -1670,13 +1765,12 @@ class RelationshipProperty(StrategizedProperty):
instance_mapper = instance_state.manager.mapper
if not instance_mapper.isa(self.mapper.class_manager.mapper):
- raise AssertionError("Attribute '%s' on class '%s' "
- "doesn't handle objects "
- "of type '%s'" % (
- self.key,
- self.parent.class_,
- c.__class__
- ))
+ raise AssertionError(
+ "Attribute '%s' on class '%s' "
+ "doesn't handle objects "
+ "of type '%s'"
+ % (self.key, self.parent.class_, c.__class__)
+ )
visited_states.add(instance_state)
@@ -1689,18 +1783,22 @@ class RelationshipProperty(StrategizedProperty):
if not other.mapper.common_parent(self.parent):
raise sa_exc.ArgumentError(
- 'reverse_property %r on '
- 'relationship %s references relationship %s, which '
- 'does not reference mapper %s' %
- (key, self, other, self.parent))
+ "reverse_property %r on "
+ "relationship %s references relationship %s, which "
+ "does not reference mapper %s"
+ % (key, self, other, self.parent)
+ )
- if self.direction in (ONETOMANY, MANYTOONE) and self.direction \
- == other.direction:
+ if (
+ self.direction in (ONETOMANY, MANYTOONE)
+ and self.direction == other.direction
+ ):
raise sa_exc.ArgumentError(
- '%s and back-reference %s are '
- 'both of the same direction %r. Did you mean to '
- 'set remote_side on the many-to-one side ?' %
- (other, self, self.direction))
+ "%s and back-reference %s are "
+ "both of the same direction %r. Did you mean to "
+ "set remote_side on the many-to-one side ?"
+ % (other, self, self.direction)
+ )
@util.memoized_property
def mapper(self):
@@ -1710,22 +1808,23 @@ class RelationshipProperty(StrategizedProperty):
This is a lazy-initializing static attribute.
"""
- if util.callable(self.argument) and \
- not isinstance(self.argument, (type, mapperlib.Mapper)):
+ if util.callable(self.argument) and not isinstance(
+ self.argument, (type, mapperlib.Mapper)
+ ):
argument = self.argument()
else:
argument = self.argument
if isinstance(argument, type):
- mapper_ = mapperlib.class_mapper(argument,
- configure=False)
+ mapper_ = mapperlib.class_mapper(argument, configure=False)
elif isinstance(self.argument, mapperlib.Mapper):
mapper_ = argument
else:
raise sa_exc.ArgumentError(
"relationship '%s' expects "
"a class or a mapper argument (received: %s)"
- % (self.key, type(argument)))
+ % (self.key, type(argument))
+ )
return mapper_
@util.memoized_property
@@ -1759,8 +1858,12 @@ class RelationshipProperty(StrategizedProperty):
# deferred initialization. This technique is used
# by declarative "string configs" and some recipes.
for attr in (
- 'order_by', 'primaryjoin', 'secondaryjoin',
- 'secondary', '_user_defined_foreign_keys', 'remote_side',
+ "order_by",
+ "primaryjoin",
+ "secondaryjoin",
+ "secondary",
+ "_user_defined_foreign_keys",
+ "remote_side",
):
attr_value = getattr(self, attr)
if util.callable(attr_value):
@@ -1768,11 +1871,15 @@ class RelationshipProperty(StrategizedProperty):
# remove "annotations" which are present if mapped class
# descriptors are used to create the join expression.
- for attr in 'primaryjoin', 'secondaryjoin':
+ for attr in "primaryjoin", "secondaryjoin":
val = getattr(self, attr)
if val is not None:
- setattr(self, attr, _orm_deannotate(
- expression._only_column_elements(val, attr))
+ setattr(
+ self,
+ attr,
+ _orm_deannotate(
+ expression._only_column_elements(val, attr)
+ ),
)
# ensure expressions in self.order_by, foreign_keys,
@@ -1780,21 +1887,18 @@ class RelationshipProperty(StrategizedProperty):
if self.order_by is not False and self.order_by is not None:
self.order_by = [
expression._only_column_elements(x, "order_by")
- for x in
- util.to_list(self.order_by)]
-
- self._user_defined_foreign_keys = \
- util.column_set(
- expression._only_column_elements(x, "foreign_keys")
- for x in util.to_column_set(
- self._user_defined_foreign_keys
- ))
-
- self.remote_side = \
- util.column_set(
- expression._only_column_elements(x, "remote_side")
- for x in
- util.to_column_set(self.remote_side))
+ for x in util.to_list(self.order_by)
+ ]
+
+ self._user_defined_foreign_keys = util.column_set(
+ expression._only_column_elements(x, "foreign_keys")
+ for x in util.to_column_set(self._user_defined_foreign_keys)
+ )
+
+ self.remote_side = util.column_set(
+ expression._only_column_elements(x, "remote_side")
+ for x in util.to_column_set(self.remote_side)
+ )
self.target = self.mapper.mapped_table
@@ -1815,7 +1919,7 @@ class RelationshipProperty(StrategizedProperty):
self_referential=self._is_self_referential,
prop=self,
support_sync=not self.viewonly,
- can_be_synced_fn=self._columns_are_mapped
+ can_be_synced_fn=self._columns_are_mapped,
)
self.primaryjoin = jc.primaryjoin
self.secondaryjoin = jc.secondaryjoin
@@ -1832,16 +1936,20 @@ class RelationshipProperty(StrategizedProperty):
inheritance conflicts."""
if self.parent.non_primary and not mapperlib.class_mapper(
- self.parent.class_,
- configure=False).has_property(self.key):
+ self.parent.class_, configure=False
+ ).has_property(self.key):
raise sa_exc.ArgumentError(
"Attempting to assign a new "
"relationship '%s' to a non-primary mapper on "
"class '%s'. New relationships can only be added "
"to the primary mapper, i.e. the very first mapper "
- "created for class '%s' " %
- (self.key, self.parent.class_.__name__,
- self.parent.class_.__name__))
+ "created for class '%s' "
+ % (
+ self.key,
+ self.parent.class_.__name__,
+ self.parent.class_.__name__,
+ )
+ )
def _get_cascade(self):
"""Return the current cascade setting for this
@@ -1851,7 +1959,7 @@ class RelationshipProperty(StrategizedProperty):
def _set_cascade(self, cascade):
cascade = CascadeOptions(cascade)
- if 'mapper' in self.__dict__:
+ if "mapper" in self.__dict__:
self._check_cascade_settings(cascade)
self._cascade = cascade
@@ -1861,27 +1969,31 @@ class RelationshipProperty(StrategizedProperty):
cascade = property(_get_cascade, _set_cascade)
def _check_cascade_settings(self, cascade):
- if cascade.delete_orphan and not self.single_parent \
- and (self.direction is MANYTOMANY or self.direction
- is MANYTOONE):
+ if (
+ cascade.delete_orphan
+ and not self.single_parent
+ and (self.direction is MANYTOMANY or self.direction is MANYTOONE)
+ ):
raise sa_exc.ArgumentError(
- 'On %s, delete-orphan cascade is not supported '
- 'on a many-to-many or many-to-one relationship '
- 'when single_parent is not set. Set '
- 'single_parent=True on the relationship().'
- % self)
+ "On %s, delete-orphan cascade is not supported "
+ "on a many-to-many or many-to-one relationship "
+ "when single_parent is not set. Set "
+ "single_parent=True on the relationship()." % self
+ )
if self.direction is MANYTOONE and self.passive_deletes:
- util.warn("On %s, 'passive_deletes' is normally configured "
- "on one-to-many, one-to-one, many-to-many "
- "relationships only."
- % self)
-
- if self.passive_deletes == 'all' and \
- ("delete" in cascade or
- "delete-orphan" in cascade):
+ util.warn(
+ "On %s, 'passive_deletes' is normally configured "
+ "on one-to-many, one-to-one, many-to-many "
+ "relationships only." % self
+ )
+
+ if self.passive_deletes == "all" and (
+ "delete" in cascade or "delete-orphan" in cascade
+ ):
raise sa_exc.ArgumentError(
"On %s, can't set passive_deletes='all' in conjunction "
- "with 'delete' or 'delete-orphan' cascade" % self)
+ "with 'delete' or 'delete-orphan' cascade" % self
+ )
if cascade.delete_orphan:
self.mapper.primary_mapper()._delete_orphans.append(
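# Satisfying the first check above, sketched: delete-orphan on a
# many-to-one (or many-to-many) requires single_parent=True
# (illustrative Child.parent):
parent = relationship(
    "Parent",
    cascade="all, delete-orphan",
    single_parent=True,
)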
@@ -1894,8 +2006,10 @@ class RelationshipProperty(StrategizedProperty):
"""
- return self.key in mapper.relationships and \
- mapper.relationships[self.key] is self
+ return (
+ self.key in mapper.relationships
+ and mapper.relationships[self.key] is self
+ )
def _columns_are_mapped(self, *cols):
"""Return True if all columns in the given collection are
@@ -1903,11 +2017,14 @@ class RelationshipProperty(StrategizedProperty):
"""
for c in cols:
- if self.secondary is not None \
- and self.secondary.c.contains_column(c):
+ if (
+ self.secondary is not None
+ and self.secondary.c.contains_column(c)
+ ):
continue
- if not self.parent.mapped_table.c.contains_column(c) and \
- not self.target.c.contains_column(c):
+ if not self.parent.mapped_table.c.contains_column(
+ c
+ ) and not self.target.c.contains_column(c):
return False
return True
@@ -1925,15 +2042,17 @@ class RelationshipProperty(StrategizedProperty):
mapper = self.mapper.primary_mapper()
if not mapper.concrete:
- check = set(mapper.iterate_to_root()).\
- union(mapper.self_and_descendants)
+ check = set(mapper.iterate_to_root()).union(
+ mapper.self_and_descendants
+ )
for m in check:
if m.has_property(backref_key) and not m.concrete:
raise sa_exc.ArgumentError(
"Error creating backref "
"'%s' on relationship '%s': property of that "
- "name exists on mapper '%s'" %
- (backref_key, self, m))
+ "name exists on mapper '%s'"
+ % (backref_key, self, m)
+ )
# determine primaryjoin/secondaryjoin for the
# backref. Use the one we had, so that
@@ -1944,35 +2063,42 @@ class RelationshipProperty(StrategizedProperty):
# secondaryjoin. use the annotated
# pj/sj on the _join_condition.
pj = kwargs.pop(
- 'primaryjoin',
- self._join_condition.secondaryjoin_minus_local)
+ "primaryjoin",
+ self._join_condition.secondaryjoin_minus_local,
+ )
sj = kwargs.pop(
- 'secondaryjoin',
- self._join_condition.primaryjoin_minus_local)
+ "secondaryjoin",
+ self._join_condition.primaryjoin_minus_local,
+ )
else:
pj = kwargs.pop(
- 'primaryjoin',
- self._join_condition.primaryjoin_reverse_remote)
- sj = kwargs.pop('secondaryjoin', None)
+ "primaryjoin",
+ self._join_condition.primaryjoin_reverse_remote,
+ )
+ sj = kwargs.pop("secondaryjoin", None)
if sj:
raise sa_exc.InvalidRequestError(
"Can't assign 'secondaryjoin' on a backref "
"against a non-secondary relationship."
)
- foreign_keys = kwargs.pop('foreign_keys',
- self._user_defined_foreign_keys)
+ foreign_keys = kwargs.pop(
+ "foreign_keys", self._user_defined_foreign_keys
+ )
parent = self.parent.primary_mapper()
- kwargs.setdefault('viewonly', self.viewonly)
- kwargs.setdefault('post_update', self.post_update)
- kwargs.setdefault('passive_updates', self.passive_updates)
+ kwargs.setdefault("viewonly", self.viewonly)
+ kwargs.setdefault("post_update", self.post_update)
+ kwargs.setdefault("passive_updates", self.passive_updates)
self.back_populates = backref_key
relationship = RelationshipProperty(
- parent, self.secondary,
- pj, sj,
+ parent,
+ self.secondary,
+ pj,
+ sj,
foreign_keys=foreign_keys,
back_populates=self.key,
- **kwargs)
+ **kwargs
+ )
mapper._configure_property(backref_key, relationship)
if self.back_populates:
@@ -1982,8 +2108,9 @@ class RelationshipProperty(StrategizedProperty):
if self.uselist is None:
self.uselist = self.direction is not MANYTOONE
if not self.viewonly:
- self._dependency_processor = \
- dependency.DependencyProcessor.from_relationship(self)
+ self._dependency_processor = dependency.DependencyProcessor.from_relationship(
+ self
+ )
@util.memoized_property
def _use_get(self):
@@ -1997,9 +2124,14 @@ class RelationshipProperty(StrategizedProperty):
def _is_self_referential(self):
return self.mapper.common_parent(self.parent)
- def _create_joins(self, source_polymorphic=False,
- source_selectable=None, dest_polymorphic=False,
- dest_selectable=None, of_type=None):
+ def _create_joins(
+ self,
+ source_polymorphic=False,
+ source_selectable=None,
+ dest_polymorphic=False,
+ dest_selectable=None,
+ of_type=None,
+ ):
if source_selectable is None:
if source_polymorphic and self.parent.with_polymorphic:
source_selectable = self.parent._with_polymorphic_selectable
@@ -2023,16 +2155,21 @@ class RelationshipProperty(StrategizedProperty):
single_crit = dest_mapper._single_table_criterion
aliased = aliased or (source_selectable is not None)
- primaryjoin, secondaryjoin, secondary, target_adapter, dest_selectable = \
- self._join_condition.join_targets(
- source_selectable, dest_selectable, aliased, single_crit
- )
+ primaryjoin, secondaryjoin, secondary, target_adapter, dest_selectable = self._join_condition.join_targets(
+ source_selectable, dest_selectable, aliased, single_crit
+ )
if source_selectable is None:
source_selectable = self.parent.local_table
if dest_selectable is None:
dest_selectable = self.mapper.local_table
- return (primaryjoin, secondaryjoin, source_selectable,
- dest_selectable, secondary, target_adapter)
+ return (
+ primaryjoin,
+ secondaryjoin,
+ source_selectable,
+ dest_selectable,
+ secondary,
+ target_adapter,
+ )
def _annotate_columns(element, annotations):
@@ -2048,24 +2185,25 @@ def _annotate_columns(element, annotations):
class JoinCondition(object):
- def __init__(self,
- parent_selectable,
- child_selectable,
- parent_local_selectable,
- child_local_selectable,
- primaryjoin=None,
- secondary=None,
- secondaryjoin=None,
- parent_equivalents=None,
- child_equivalents=None,
- consider_as_foreign_keys=None,
- local_remote_pairs=None,
- remote_side=None,
- self_referential=False,
- prop=None,
- support_sync=True,
- can_be_synced_fn=lambda *c: True
- ):
+ def __init__(
+ self,
+ parent_selectable,
+ child_selectable,
+ parent_local_selectable,
+ child_local_selectable,
+ primaryjoin=None,
+ secondary=None,
+ secondaryjoin=None,
+ parent_equivalents=None,
+ child_equivalents=None,
+ consider_as_foreign_keys=None,
+ local_remote_pairs=None,
+ remote_side=None,
+ self_referential=False,
+ prop=None,
+ support_sync=True,
+ can_be_synced_fn=lambda *c: True,
+ ):
self.parent_selectable = parent_selectable
self.parent_local_selectable = parent_local_selectable
self.child_selectable = child_selectable
@@ -2100,27 +2238,41 @@ class JoinCondition(object):
if self.prop is None:
return
log = self.prop.logger
- log.info('%s setup primary join %s', self.prop,
- self.primaryjoin)
- log.info('%s setup secondary join %s', self.prop,
- self.secondaryjoin)
- log.info('%s synchronize pairs [%s]', self.prop,
- ','.join('(%s => %s)' % (l, r) for (l, r) in
- self.synchronize_pairs))
- log.info('%s secondary synchronize pairs [%s]', self.prop,
- ','.join('(%s => %s)' % (l, r) for (l, r) in
- self.secondary_synchronize_pairs or []))
- log.info('%s local/remote pairs [%s]', self.prop,
- ','.join('(%s / %s)' % (l, r) for (l, r) in
- self.local_remote_pairs))
- log.info('%s remote columns [%s]', self.prop,
- ','.join('%s' % col for col in self.remote_columns)
- )
- log.info('%s local columns [%s]', self.prop,
- ','.join('%s' % col for col in self.local_columns)
- )
- log.info('%s relationship direction %s', self.prop,
- self.direction)
+ log.info("%s setup primary join %s", self.prop, self.primaryjoin)
+ log.info("%s setup secondary join %s", self.prop, self.secondaryjoin)
+ log.info(
+ "%s synchronize pairs [%s]",
+ self.prop,
+ ",".join(
+ "(%s => %s)" % (l, r) for (l, r) in self.synchronize_pairs
+ ),
+ )
+ log.info(
+ "%s secondary synchronize pairs [%s]",
+ self.prop,
+ ",".join(
+ "(%s => %s)" % (l, r)
+ for (l, r) in self.secondary_synchronize_pairs or []
+ ),
+ )
+ log.info(
+ "%s local/remote pairs [%s]",
+ self.prop,
+ ",".join(
+ "(%s / %s)" % (l, r) for (l, r) in self.local_remote_pairs
+ ),
+ )
+ log.info(
+ "%s remote columns [%s]",
+ self.prop,
+ ",".join("%s" % col for col in self.remote_columns),
+ )
+ log.info(
+ "%s local columns [%s]",
+ self.prop,
+ ",".join("%s" % col for col in self.local_columns),
+ )
+ log.info("%s relationship direction %s", self.prop, self.direction)
def _sanitize_joins(self):
"""remove the parententity annotation from our join conditions which
@@ -2133,10 +2285,12 @@ class JoinCondition(object):
"""
self.primaryjoin = _deep_deannotate(
- self.primaryjoin, values=("parententity",))
+ self.primaryjoin, values=("parententity",)
+ )
if self.secondaryjoin is not None:
self.secondaryjoin = _deep_deannotate(
- self.secondaryjoin, values=("parententity",))
+ self.secondaryjoin, values=("parententity",)
+ )
def _determine_joins(self):
"""Determine the 'primaryjoin' and 'secondaryjoin' attributes,
@@ -2150,7 +2304,8 @@ class JoinCondition(object):
raise sa_exc.ArgumentError(
"Property %s specified with secondary "
"join condition but "
- "no secondary argument" % self.prop)
+ "no secondary argument" % self.prop
+ )
# find a join between the given mapper's mapped table and
# the given table. will try the mapper's local table first
@@ -2161,30 +2316,27 @@ class JoinCondition(object):
consider_as_foreign_keys = self.consider_as_foreign_keys or None
if self.secondary is not None:
if self.secondaryjoin is None:
- self.secondaryjoin = \
- join_condition(
- self.child_selectable,
- self.secondary,
- a_subset=self.child_local_selectable,
- consider_as_foreign_keys=consider_as_foreign_keys
- )
+ self.secondaryjoin = join_condition(
+ self.child_selectable,
+ self.secondary,
+ a_subset=self.child_local_selectable,
+ consider_as_foreign_keys=consider_as_foreign_keys,
+ )
if self.primaryjoin is None:
- self.primaryjoin = \
- join_condition(
- self.parent_selectable,
- self.secondary,
- a_subset=self.parent_local_selectable,
- consider_as_foreign_keys=consider_as_foreign_keys
- )
+ self.primaryjoin = join_condition(
+ self.parent_selectable,
+ self.secondary,
+ a_subset=self.parent_local_selectable,
+ consider_as_foreign_keys=consider_as_foreign_keys,
+ )
else:
if self.primaryjoin is None:
- self.primaryjoin = \
- join_condition(
- self.parent_selectable,
- self.child_selectable,
- a_subset=self.parent_local_selectable,
- consider_as_foreign_keys=consider_as_foreign_keys
- )
+ self.primaryjoin = join_condition(
+ self.parent_selectable,
+ self.child_selectable,
+ a_subset=self.parent_local_selectable,
+ consider_as_foreign_keys=consider_as_foreign_keys,
+ )
except sa_exc.NoForeignKeysError:
if self.secondary is not None:
raise sa_exc.NoForeignKeysError(
@@ -2195,7 +2347,8 @@ class JoinCondition(object):
"Ensure that referencing columns are associated "
"with a ForeignKey or ForeignKeyConstraint, or "
"specify 'primaryjoin' and 'secondaryjoin' "
- "expressions." % (self.prop, self.secondary))
+ "expressions." % (self.prop, self.secondary)
+ )
else:
raise sa_exc.NoForeignKeysError(
"Could not determine join "
@@ -2204,7 +2357,8 @@ class JoinCondition(object):
"linking these tables. "
"Ensure that referencing columns are associated "
"with a ForeignKey or ForeignKeyConstraint, or "
- "specify a 'primaryjoin' expression." % self.prop)
+ "specify a 'primaryjoin' expression." % self.prop
+ )
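# The remedy both NoForeignKeysError branches point at, sketched: spell
# the join out when no ForeignKey/ForeignKeyConstraint links the tables
# (columns illustrative; foreign() marks the referencing side).
children = relationship(
    "Child",
    primaryjoin="Parent.id == foreign(Child.parent_id)",
)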
except sa_exc.AmbiguousForeignKeysError:
if self.secondary is not None:
raise sa_exc.AmbiguousForeignKeysError(
@@ -2216,8 +2370,8 @@ class JoinCondition(object):
"argument, providing a list of those columns which "
"should be counted as containing a foreign key "
"reference from the secondary table to each of the "
- "parent and child tables."
- % (self.prop, self.secondary))
+ "parent and child tables." % (self.prop, self.secondary)
+ )
else:
raise sa_exc.AmbiguousForeignKeysError(
"Could not determine join "
@@ -2226,8 +2380,8 @@ class JoinCondition(object):
"paths linking the tables. Specify the "
"'foreign_keys' argument, providing a list of those "
"columns which should be counted as containing a "
- "foreign key reference to the parent table."
- % self.prop)
+ "foreign key reference to the parent table." % self.prop
+ )
@property
def primaryjoin_minus_local(self):
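# The remedy the AmbiguousForeignKeysError branches point at, sketched:
# two foreign keys to the same table, disambiguated per relationship()
# with the foreign_keys argument (names illustrative).
class Invoice(Base):
    __tablename__ = "invoice"
    id = Column(Integer, primary_key=True)
    billing_address_id = Column(Integer, ForeignKey("address.id"))
    shipping_address_id = Column(Integer, ForeignKey("address.id"))
    billing_address = relationship(
        "Address", foreign_keys=[billing_address_id]
    )
    shipping_address = relationship(
        "Address", foreign_keys=[shipping_address_id]
    )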
@@ -2235,8 +2389,7 @@ class JoinCondition(object):
@property
def secondaryjoin_minus_local(self):
- return _deep_deannotate(self.secondaryjoin,
- values=("local", "remote"))
+ return _deep_deannotate(self.secondaryjoin, values=("local", "remote"))
@util.memoized_property
def primaryjoin_reverse_remote(self):
@@ -2250,24 +2403,26 @@ class JoinCondition(object):
"""
if self._has_remote_annotations:
+
def replace(element):
if "remote" in element._annotations:
v = element._annotations.copy()
- del v['remote']
- v['local'] = True
+ del v["remote"]
+ v["local"] = True
return element._with_annotations(v)
elif "local" in element._annotations:
v = element._annotations.copy()
- del v['local']
- v['remote'] = True
+ del v["local"]
+ v["remote"] = True
return element._with_annotations(v)
- return visitors.replacement_traverse(
- self.primaryjoin, {}, replace)
+
+ return visitors.replacement_traverse(self.primaryjoin, {}, replace)
else:
if self._has_foreign_annotations:
# TODO: coverage
- return _deep_deannotate(self.primaryjoin,
- values=("local", "remote"))
+ return _deep_deannotate(
+ self.primaryjoin, values=("local", "remote")
+ )
else:
return _deep_deannotate(self.primaryjoin)
@@ -2304,16 +2459,13 @@ class JoinCondition(object):
def check_fk(col):
if col in self.consider_as_foreign_keys:
return col._annotate({"foreign": True})
+
self.primaryjoin = visitors.replacement_traverse(
- self.primaryjoin,
- {},
- check_fk
+ self.primaryjoin, {}, check_fk
)
if self.secondaryjoin is not None:
self.secondaryjoin = visitors.replacement_traverse(
- self.secondaryjoin,
- {},
- check_fk
+ self.secondaryjoin, {}, check_fk
)
def _annotate_present_fks(self):
@@ -2323,8 +2475,7 @@ class JoinCondition(object):
secondarycols = set()
def is_foreign(a, b):
- if isinstance(a, schema.Column) and \
- isinstance(b, schema.Column):
+ if isinstance(a, schema.Column) and isinstance(b, schema.Column):
if a.references(b):
return a
elif b.references(a):
@@ -2337,31 +2488,30 @@ class JoinCondition(object):
return b
def visit_binary(binary):
- if not isinstance(binary.left, sql.ColumnElement) or \
- not isinstance(binary.right, sql.ColumnElement):
+ if not isinstance(
+ binary.left, sql.ColumnElement
+ ) or not isinstance(binary.right, sql.ColumnElement):
return
- if "foreign" not in binary.left._annotations and \
- "foreign" not in binary.right._annotations:
+ if (
+ "foreign" not in binary.left._annotations
+ and "foreign" not in binary.right._annotations
+ ):
col = is_foreign(binary.left, binary.right)
if col is not None:
if col.compare(binary.left):
- binary.left = binary.left._annotate(
- {"foreign": True})
+ binary.left = binary.left._annotate({"foreign": True})
elif col.compare(binary.right):
binary.right = binary.right._annotate(
- {"foreign": True})
+ {"foreign": True}
+ )
self.primaryjoin = visitors.cloned_traverse(
- self.primaryjoin,
- {},
- {"binary": visit_binary}
+ self.primaryjoin, {}, {"binary": visit_binary}
)
if self.secondaryjoin is not None:
self.secondaryjoin = visitors.cloned_traverse(
- self.secondaryjoin,
- {},
- {"binary": visit_binary}
+ self.secondaryjoin, {}, {"binary": visit_binary}
)
def _refers_to_parent_table(self):
@@ -2376,26 +2526,24 @@ class JoinCondition(object):
def visit_binary(binary):
c, f = binary.left, binary.right
if (
- isinstance(c, expression.ColumnClause) and
- isinstance(f, expression.ColumnClause) and
- pt.is_derived_from(c.table) and
- pt.is_derived_from(f.table) and
- mt.is_derived_from(c.table) and
- mt.is_derived_from(f.table)
+ isinstance(c, expression.ColumnClause)
+ and isinstance(f, expression.ColumnClause)
+ and pt.is_derived_from(c.table)
+ and pt.is_derived_from(f.table)
+ and mt.is_derived_from(c.table)
+ and mt.is_derived_from(f.table)
):
result[0] = True
- visitors.traverse(
- self.primaryjoin,
- {},
- {"binary": visit_binary}
- )
+
+ visitors.traverse(self.primaryjoin, {}, {"binary": visit_binary})
return result[0]
def _tables_overlap(self):
"""Return True if parent/child tables have some overlap."""
return selectables_overlap(
- self.parent_selectable, self.child_selectable)
+ self.parent_selectable, self.child_selectable
+ )
def _annotate_remote(self):
"""Annotate the primaryjoin and secondaryjoin
@@ -2411,7 +2559,9 @@ class JoinCondition(object):
elif self._local_remote_pairs or self._remote_side:
self._annotate_remote_from_args()
elif self._refers_to_parent_table():
- self._annotate_selfref(lambda col: "foreign" in col._annotations, False)
+ self._annotate_selfref(
+ lambda col: "foreign" in col._annotations, False
+ )
elif self._tables_overlap():
self._annotate_remote_with_overlap()
else:
@@ -2422,35 +2572,40 @@ class JoinCondition(object):
when 'secondary' is present.
"""
+
def repl(element):
if self.secondary.c.contains_column(element):
return element._annotate({"remote": True})
+
self.primaryjoin = visitors.replacement_traverse(
- self.primaryjoin, {}, repl)
+ self.primaryjoin, {}, repl
+ )
self.secondaryjoin = visitors.replacement_traverse(
- self.secondaryjoin, {}, repl)
+ self.secondaryjoin, {}, repl
+ )
def _annotate_selfref(self, fn, remote_side_given):
"""annotate 'remote' in primaryjoin, secondaryjoin
when the relationship is detected as self-referential.
"""
+
def visit_binary(binary):
equated = binary.left.compare(binary.right)
- if isinstance(binary.left, expression.ColumnClause) and \
- isinstance(binary.right, expression.ColumnClause):
+ if isinstance(binary.left, expression.ColumnClause) and isinstance(
+ binary.right, expression.ColumnClause
+ ):
# assume one to many - FKs are "remote"
if fn(binary.left):
binary.left = binary.left._annotate({"remote": True})
if fn(binary.right) and not equated:
- binary.right = binary.right._annotate(
- {"remote": True})
+ binary.right = binary.right._annotate({"remote": True})
elif not remote_side_given:
self._warn_non_column_elements()
self.primaryjoin = visitors.cloned_traverse(
- self.primaryjoin, {},
- {"binary": visit_binary})
+ self.primaryjoin, {}, {"binary": visit_binary}
+ )
def _annotate_remote_from_args(self):
"""annotate 'remote' in primaryjoin, secondaryjoin
@@ -2463,7 +2618,8 @@ class JoinCondition(object):
raise sa_exc.ArgumentError(
"remote_side argument is redundant "
"against more detailed _local_remote_side "
- "argument.")
+ "argument."
+ )
remote_side = [r for (l, r) in self._local_remote_pairs]
else:
@@ -2472,11 +2628,14 @@ class JoinCondition(object):
if self._refers_to_parent_table():
self._annotate_selfref(lambda col: col in remote_side, True)
else:
+
def repl(element):
if element in remote_side:
return element._annotate({"remote": True})
+
self.primaryjoin = visitors.replacement_traverse(
- self.primaryjoin, {}, repl)
+ self.primaryjoin, {}, repl
+ )
def _annotate_remote_with_overlap(self):
"""annotate 'remote' in primaryjoin, secondaryjoin
@@ -2485,26 +2644,36 @@ class JoinCondition(object):
relationship.
"""
+
def visit_binary(binary):
- binary.left, binary.right = proc_left_right(binary.left,
- binary.right)
- binary.right, binary.left = proc_left_right(binary.right,
- binary.left)
+ binary.left, binary.right = proc_left_right(
+ binary.left, binary.right
+ )
+ binary.right, binary.left = proc_left_right(
+ binary.right, binary.left
+ )
- check_entities = self.prop is not None and \
- self.prop.mapper is not self.prop.parent
+ check_entities = (
+ self.prop is not None and self.prop.mapper is not self.prop.parent
+ )
def proc_left_right(left, right):
- if isinstance(left, expression.ColumnClause) and \
- isinstance(right, expression.ColumnClause):
- if self.child_selectable.c.contains_column(right) and \
- self.parent_selectable.c.contains_column(left):
+ if isinstance(left, expression.ColumnClause) and isinstance(
+ right, expression.ColumnClause
+ ):
+ if self.child_selectable.c.contains_column(
+ right
+ ) and self.parent_selectable.c.contains_column(left):
right = right._annotate({"remote": True})
- elif check_entities and \
- right._annotations.get('parentmapper') is self.prop.mapper:
+ elif (
+ check_entities
+ and right._annotations.get("parentmapper") is self.prop.mapper
+ ):
right = right._annotate({"remote": True})
- elif check_entities and \
- left._annotations.get('parentmapper') is self.prop.mapper:
+ elif (
+ check_entities
+ and left._annotations.get("parentmapper") is self.prop.mapper
+ ):
left = left._annotate({"remote": True})
else:
self._warn_non_column_elements()
@@ -2512,8 +2681,8 @@ class JoinCondition(object):
return left, right
self.primaryjoin = visitors.cloned_traverse(
- self.primaryjoin, {},
- {"binary": visit_binary})
+ self.primaryjoin, {}, {"binary": visit_binary}
+ )
def _annotate_remote_distinct_selectables(self):
"""annotate 'remote' in primaryjoin, secondaryjoin
@@ -2521,22 +2690,23 @@ class JoinCondition(object):
separate.
"""
+
def repl(element):
- if self.child_selectable.c.contains_column(element) and \
- (not self.parent_local_selectable.c.
- contains_column(element) or
- self.child_local_selectable.c.
- contains_column(element)):
+ if self.child_selectable.c.contains_column(element) and (
+ not self.parent_local_selectable.c.contains_column(element)
+ or self.child_local_selectable.c.contains_column(element)
+ ):
return element._annotate({"remote": True})
+
self.primaryjoin = visitors.replacement_traverse(
- self.primaryjoin, {}, repl)
+ self.primaryjoin, {}, repl
+ )
def _warn_non_column_elements(self):
util.warn(
"Non-simple column elements in primary "
"join condition for property %s - consider using "
- "remote() annotations to mark the remote side."
- % self.prop
+ "remote() annotations to mark the remote side." % self.prop
)
def _annotate_local(self):
@@ -2554,15 +2724,16 @@ class JoinCondition(object):
return
if self._local_remote_pairs:
- local_side = util.column_set([l for (l, r)
- in self._local_remote_pairs])
+ local_side = util.column_set(
+ [l for (l, r) in self._local_remote_pairs]
+ )
else:
local_side = util.column_set(self.parent_selectable.c)
def locals_(elem):
- if "remote" not in elem._annotations and \
- elem in local_side:
+ if "remote" not in elem._annotations and elem in local_side:
return elem._annotate({"local": True})
+
self.primaryjoin = visitors.replacement_traverse(
self.primaryjoin, {}, locals_
)
@@ -2576,6 +2747,7 @@ class JoinCondition(object):
return elem._annotate({"parentmapper": self.prop.mapper})
elif "local" in elem._annotations:
return elem._annotate({"parentmapper": self.prop.parent})
+
self.primaryjoin = visitors.replacement_traverse(
self.primaryjoin, {}, parentmappers_
)
@@ -2583,14 +2755,15 @@ class JoinCondition(object):
def _check_remote_side(self):
if not self.local_remote_pairs:
raise sa_exc.ArgumentError(
- 'Relationship %s could '
- 'not determine any unambiguous local/remote column '
- 'pairs based on join condition and remote_side '
- 'arguments. '
- 'Consider using the remote() annotation to '
- 'accurately mark those elements of the join '
- 'condition that are on the remote side of '
- 'the relationship.' % (self.prop, ))
+ "Relationship %s could "
+ "not determine any unambiguous local/remote column "
+ "pairs based on join condition and remote_side "
+ "arguments. "
+ "Consider using the remote() annotation to "
+ "accurately mark those elements of the join "
+ "condition that are on the remote side of "
+ "the relationship." % (self.prop,)
+ )
def _check_foreign_cols(self, join_condition, primary):
"""Check the foreign key columns collected and emit error
@@ -2599,7 +2772,8 @@ class JoinCondition(object):
can_sync = False
foreign_cols = self._gather_columns_with_annotation(
- join_condition, "foreign")
+ join_condition, "foreign"
+ )
has_foreign = bool(foreign_cols)
@@ -2608,42 +2782,53 @@ class JoinCondition(object):
else:
can_sync = bool(self.secondary_synchronize_pairs)
- if self.support_sync and can_sync or \
- (not self.support_sync and has_foreign):
+ if (
+ self.support_sync
+ and can_sync
+ or (not self.support_sync and has_foreign)
+ ):
return
# from here below is just determining the best error message
# to report. Check for a join condition using any operator
# (not just ==), perhaps they need to turn on "viewonly=True".
if self.support_sync and has_foreign and not can_sync:
- err = "Could not locate any simple equality expressions "\
- "involving locally mapped foreign key columns for "\
- "%s join condition "\
- "'%s' on relationship %s." % (
- primary and 'primary' or 'secondary',
+ err = (
+ "Could not locate any simple equality expressions "
+ "involving locally mapped foreign key columns for "
+ "%s join condition "
+ "'%s' on relationship %s."
+ % (
+ primary and "primary" or "secondary",
join_condition,
- self.prop
+ self.prop,
)
- err += \
- " Ensure that referencing columns are associated "\
- "with a ForeignKey or ForeignKeyConstraint, or are "\
- "annotated in the join condition with the foreign() "\
- "annotation. To allow comparison operators other than "\
+ )
+ err += (
+ " Ensure that referencing columns are associated "
+ "with a ForeignKey or ForeignKeyConstraint, or are "
+ "annotated in the join condition with the foreign() "
+ "annotation. To allow comparison operators other than "
"'==', the relationship can be marked as viewonly=True."
+ )
raise sa_exc.ArgumentError(err)
else:
- err = "Could not locate any relevant foreign key columns "\
- "for %s join condition '%s' on relationship %s." % (
- primary and 'primary' or 'secondary',
+ err = (
+ "Could not locate any relevant foreign key columns "
+ "for %s join condition '%s' on relationship %s."
+ % (
+ primary and "primary" or "secondary",
join_condition,
- self.prop
+ self.prop,
)
- err += \
- ' Ensure that referencing columns are associated '\
- 'with a ForeignKey or ForeignKeyConstraint, or are '\
- 'annotated in the join condition with the foreign() '\
- 'annotation.'
+ )
+ err += (
+ " Ensure that referencing columns are associated "
+ "with a ForeignKey or ForeignKeyConstraint, or are "
+ "annotated in the join condition with the foreign() "
+ "annotation."
+ )
raise sa_exc.ArgumentError(err)
def _determine_direction(self):
@@ -2658,13 +2843,11 @@ class JoinCondition(object):
targetcols = util.column_set(self.child_selectable.c)
# fk collection which suggests ONETOMANY.
- onetomany_fk = targetcols.intersection(
- self.foreign_key_columns)
+ onetomany_fk = targetcols.intersection(self.foreign_key_columns)
# fk collection which suggests MANYTOONE.
- manytoone_fk = parentcols.intersection(
- self.foreign_key_columns)
+ manytoone_fk = parentcols.intersection(self.foreign_key_columns)
if onetomany_fk and manytoone_fk:
# fks on both sides. test for overlap of local/remote
@@ -2676,15 +2859,20 @@ class JoinCondition(object):
# 1. columns that are both remote and FK suggest
# onetomany.
onetomany_local = self._gather_columns_with_annotation(
- self.primaryjoin, "remote", "foreign")
+ self.primaryjoin, "remote", "foreign"
+ )
# 2. columns that are FK but are not remote (e.g. local)
# suggest manytoone.
- manytoone_local = set([c for c in
- self._gather_columns_with_annotation(
- self.primaryjoin,
- "foreign")
- if "remote" not in c._annotations])
+ manytoone_local = set(
+ [
+ c
+ for c in self._gather_columns_with_annotation(
+ self.primaryjoin, "foreign"
+ )
+ if "remote" not in c._annotations
+ ]
+ )
# 3. if both collections are present, remove columns that
# refer to themselves. This is for the case of
@@ -2713,7 +2901,8 @@ class JoinCondition(object):
"Ensure that only those columns referring "
"to a parent column are marked as foreign, "
"either via the foreign() annotation or "
- "via the foreign_keys argument." % self.prop)
+ "via the foreign_keys argument." % self.prop
+ )
elif onetomany_fk:
self.direction = ONETOMANY
elif manytoone_fk:
@@ -2723,7 +2912,8 @@ class JoinCondition(object):
"Can't determine relationship "
"direction for relationship '%s' - foreign "
"key columns are present in neither the parent "
- "nor the child's mapped tables" % self.prop)
+ "nor the child's mapped tables" % self.prop
+ )
def _deannotate_pairs(self, collection):
"""provide deannotation for the various lists of
@@ -2732,8 +2922,7 @@ class JoinCondition(object):
original columns mapped.
"""
- return [(x._deannotate(), y._deannotate())
- for x, y in collection]
+ return [(x._deannotate(), y._deannotate()) for x, y in collection]
def _setup_pairs(self):
sync_pairs = []
@@ -2742,25 +2931,31 @@ class JoinCondition(object):
def go(joincond, collection):
def visit_binary(binary, left, right):
- if "remote" in right._annotations and \
- "remote" not in left._annotations and \
- self.can_be_synced_fn(left):
+ if (
+ "remote" in right._annotations
+ and "remote" not in left._annotations
+ and self.can_be_synced_fn(left)
+ ):
lrp.add((left, right))
- elif "remote" in left._annotations and \
- "remote" not in right._annotations and \
- self.can_be_synced_fn(right):
+ elif (
+ "remote" in left._annotations
+ and "remote" not in right._annotations
+ and self.can_be_synced_fn(right)
+ ):
lrp.add((right, left))
- if binary.operator is operators.eq and \
- self.can_be_synced_fn(left, right):
+ if binary.operator is operators.eq and self.can_be_synced_fn(
+ left, right
+ ):
if "foreign" in right._annotations:
collection.append((left, right))
elif "foreign" in left._annotations:
collection.append((right, left))
+
visit_binary_product(visit_binary, joincond)
for joincond, collection in [
(self.primaryjoin, sync_pairs),
- (self.secondaryjoin, secondary_sync_pairs)
+ (self.secondaryjoin, secondary_sync_pairs),
]:
if joincond is None:
continue
@@ -2768,8 +2963,9 @@ class JoinCondition(object):
self.local_remote_pairs = self._deannotate_pairs(lrp)
self.synchronize_pairs = self._deannotate_pairs(sync_pairs)
- self.secondary_synchronize_pairs = \
- self._deannotate_pairs(secondary_sync_pairs)
+ self.secondary_synchronize_pairs = self._deannotate_pairs(
+ secondary_sync_pairs
+ )
_track_overlapping_sync_targets = weakref.WeakKeyDictionary()
@@ -2797,20 +2993,23 @@ class JoinCondition(object):
continue
if to_ not in self._track_overlapping_sync_targets:
- self._track_overlapping_sync_targets[to_] = \
- weakref.WeakKeyDictionary({self.prop: from_})
+ self._track_overlapping_sync_targets[
+ to_
+ ] = weakref.WeakKeyDictionary({self.prop: from_})
else:
other_props = []
prop_to_from = self._track_overlapping_sync_targets[to_]
for pr, fr_ in prop_to_from.items():
- if pr.mapper in mapperlib._mapper_registry and \
- (
- self.prop._persists_for(pr.parent) or
- pr._persists_for(self.prop.parent)
- ) and \
- fr_ is not from_ and \
- pr not in self.prop._reverse_property:
+ if (
+ pr.mapper in mapperlib._mapper_registry
+ and (
+ self.prop._persists_for(pr.parent)
+ or pr._persists_for(self.prop.parent)
+ )
+ and fr_ is not from_
+ and pr not in self.prop._reverse_property
+ ):
other_props.append((pr, fr_))
@@ -2821,12 +3020,15 @@ class JoinCondition(object):
"Consider applying "
"viewonly=True to read-only relationships, or provide "
"a primaryjoin condition marking writable columns "
- "with the foreign() annotation." % (
+ "with the foreign() annotation."
+ % (
self.prop,
- from_, to_,
+ from_,
+ to_,
", ".join(
"'%s' (copies %s to %s)" % (pr, fr_, to_)
- for (pr, fr_) in other_props)
+ for (pr, fr_) in other_props
+ ),
)
)
self._track_overlapping_sync_targets[to_][self.prop] = from_
@@ -2845,27 +3047,29 @@ class JoinCondition(object):
def _gather_join_annotations(self, annotation):
s = set(
- self._gather_columns_with_annotation(
- self.primaryjoin, annotation)
+ self._gather_columns_with_annotation(self.primaryjoin, annotation)
)
if self.secondaryjoin is not None:
s.update(
self._gather_columns_with_annotation(
- self.secondaryjoin, annotation)
+ self.secondaryjoin, annotation
+ )
)
return {x._deannotate() for x in s}
def _gather_columns_with_annotation(self, clause, *annotation):
annotation = set(annotation)
- return set([
- col for col in visitors.iterate(clause, {})
- if annotation.issubset(col._annotations)
- ])
-
- def join_targets(self, source_selectable,
- dest_selectable,
- aliased,
- single_crit=None):
+ return set(
+ [
+ col
+ for col in visitors.iterate(clause, {})
+ if annotation.issubset(col._annotations)
+ ]
+ )
+
+ def join_targets(
+ self, source_selectable, dest_selectable, aliased, single_crit=None
+ ):
"""Given a source and destination selectable, create a
join between them.
@@ -2881,11 +3085,14 @@ class JoinCondition(object):
# its internal structure remains fixed
# regardless of context.
dest_selectable = _shallow_annotate(
- dest_selectable,
- {'no_replacement_traverse': True})
+ dest_selectable, {"no_replacement_traverse": True}
+ )
- primaryjoin, secondaryjoin, secondary = self.primaryjoin, \
- self.secondaryjoin, self.secondary
+ primaryjoin, secondaryjoin, secondary = (
+ self.primaryjoin,
+ self.secondaryjoin,
+ self.secondary,
+ )
# adjust the join condition for single table inheritance,
# in the case that the join is to a subclass
@@ -2902,28 +3109,31 @@ class JoinCondition(object):
if secondary is not None:
secondary = secondary.alias(flat=True)
primary_aliasizer = ClauseAdapter(secondary)
- secondary_aliasizer = \
- ClauseAdapter(dest_selectable,
- equivalents=self.child_equivalents).\
- chain(primary_aliasizer)
+ secondary_aliasizer = ClauseAdapter(
+ dest_selectable, equivalents=self.child_equivalents
+ ).chain(primary_aliasizer)
if source_selectable is not None:
- primary_aliasizer = \
- ClauseAdapter(secondary).\
- chain(ClauseAdapter(
+ primary_aliasizer = ClauseAdapter(secondary).chain(
+ ClauseAdapter(
source_selectable,
- equivalents=self.parent_equivalents))
- secondaryjoin = \
- secondary_aliasizer.traverse(secondaryjoin)
+ equivalents=self.parent_equivalents,
+ )
+ )
+ secondaryjoin = secondary_aliasizer.traverse(secondaryjoin)
else:
primary_aliasizer = ClauseAdapter(
dest_selectable,
exclude_fn=_ColInAnnotations("local"),
- equivalents=self.child_equivalents)
+ equivalents=self.child_equivalents,
+ )
if source_selectable is not None:
primary_aliasizer.chain(
- ClauseAdapter(source_selectable,
- exclude_fn=_ColInAnnotations("remote"),
- equivalents=self.parent_equivalents))
+ ClauseAdapter(
+ source_selectable,
+ exclude_fn=_ColInAnnotations("remote"),
+ equivalents=self.parent_equivalents,
+ )
+ )
secondary_aliasizer = None
primaryjoin = primary_aliasizer.traverse(primaryjoin)
@@ -2931,8 +3141,13 @@ class JoinCondition(object):
target_adapter.exclude_fn = None
else:
target_adapter = None
- return primaryjoin, secondaryjoin, secondary, \
- target_adapter, dest_selectable
+ return (
+ primaryjoin,
+ secondaryjoin,
+ secondary,
+ target_adapter,
+ dest_selectable,
+ )
def create_lazy_clause(self, reverse_direction=False):
binds = util.column_dict()
@@ -2955,28 +3170,32 @@ class JoinCondition(object):
def col_to_bind(col):
if (
- (not reverse_direction and 'local' in col._annotations) or
- reverse_direction and (
- (has_secondary and col in lookup) or
- (not has_secondary and 'remote' in col._annotations)
+ (not reverse_direction and "local" in col._annotations)
+ or reverse_direction
+ and (
+ (has_secondary and col in lookup)
+ or (not has_secondary and "remote" in col._annotations)
)
):
if col not in binds:
binds[col] = sql.bindparam(
- None, None, type_=col.type, unique=True)
+ None, None, type_=col.type, unique=True
+ )
return binds[col]
return None
lazywhere = self.primaryjoin
if self.secondaryjoin is None or not reverse_direction:
lazywhere = visitors.replacement_traverse(
- lazywhere, {}, col_to_bind)
+ lazywhere, {}, col_to_bind
+ )
if self.secondaryjoin is not None:
secondaryjoin = self.secondaryjoin
if reverse_direction:
secondaryjoin = visitors.replacement_traverse(
- secondaryjoin, {}, col_to_bind)
+ secondaryjoin, {}, col_to_bind
+ )
lazywhere = sql.and_(lazywhere, secondaryjoin)
bind_to_col = {binds[col].key: col for col in binds}
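The error messages in the hunks above repeatedly point at the remote() and foreign() annotations; for reference, a minimal sketch of spelling them out in a primaryjoin, using an illustrative adjacency-list model that is not part of this changeset:

from sqlalchemy import Column, ForeignKey, Integer, create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session, foreign, relationship, remote

Base = declarative_base()

class Node(Base):
    __tablename__ = "node"  # hypothetical table, for illustration only
    id = Column(Integer, primary_key=True)
    parent_id = Column(Integer, ForeignKey("node.id"))
    # remote() marks the related row's side of the join; foreign() marks
    # the foreign-key column. These are the hints that JoinCondition
    # otherwise derives from ForeignKey metadata and _annotate_remote().
    parent = relationship(
        "Node", primaryjoin=remote(id) == foreign(parent_id)
    )

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
session = Session(bind=engine)
session.add(Node(parent=Node()))
session.commit()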
diff --git a/lib/sqlalchemy/orm/scoping.py b/lib/sqlalchemy/orm/scoping.py
index 2e16872f9..2eeaf5b6d 100644
--- a/lib/sqlalchemy/orm/scoping.py
+++ b/lib/sqlalchemy/orm/scoping.py
@@ -11,7 +11,7 @@ from . import class_mapper, exc as orm_exc
from .session import Session
-__all__ = ['scoped_session']
+__all__ = ["scoped_session"]
class scoped_session(object):
@@ -65,7 +65,8 @@ class scoped_session(object):
if self.registry.has():
raise sa_exc.InvalidRequestError(
"Scoped session is already present; "
- "no new arguments may be specified.")
+ "no new arguments may be specified."
+ )
else:
sess = self.session_factory(**kw)
self.registry.set(sess)
@@ -99,9 +100,11 @@ class scoped_session(object):
"""
if self.registry.has():
- warn('At least one scoped session is already present. '
- ' configure() can not affect sessions that have '
- 'already been created.')
+ warn(
+ "At least one scoped session is already present. "
+ " configure() can not affect sessions that have "
+ "already been created."
+ )
self.session_factory.configure(**kwargs)
@@ -129,6 +132,7 @@ class scoped_session(object):
a class.
"""
+
class query(object):
def __get__(s, instance, owner):
try:
@@ -142,8 +146,10 @@ class scoped_session(object):
return self.registry().query(mapper)
except orm_exc.UnmappedClassError:
return None
+
return query()
+
ScopedSession = scoped_session
"""Old name for backwards compatibility."""
@@ -151,8 +157,10 @@ ScopedSession = scoped_session
def instrument(name):
def do(self, *args, **kwargs):
return getattr(self.registry(), name)(*args, **kwargs)
+
return do
+
for meth in Session.public_methods:
setattr(scoped_session, meth, instrument(meth))
@@ -166,16 +174,28 @@ def makeprop(name):
return property(get, set)
-for prop in ('bind', 'dirty', 'deleted', 'new', 'identity_map',
- 'is_active', 'autoflush', 'no_autoflush', 'info',
- 'autocommit'):
+
+for prop in (
+ "bind",
+ "dirty",
+ "deleted",
+ "new",
+ "identity_map",
+ "is_active",
+ "autoflush",
+ "no_autoflush",
+ "info",
+ "autocommit",
+):
setattr(scoped_session, prop, makeprop(prop))
def clslevel(name):
def do(cls, *args, **kwargs):
return getattr(Session, name)(*args, **kwargs)
+
return classmethod(do)
-for prop in ('close_all', 'object_session', 'identity_key'):
+
+for prop in ("close_all", "object_session", "identity_key"):
setattr(scoped_session, prop, clslevel(prop))
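For reference, a short sketch of what the instrument(), makeprop() and clslevel() loops above hang on the proxy, assuming an illustrative in-memory engine:

from sqlalchemy import create_engine
from sqlalchemy.orm import scoped_session, sessionmaker

engine = create_engine("sqlite://")  # illustrative engine
DBSession = scoped_session(sessionmaker(bind=engine))

DBSession.execute("SELECT 1")  # instrument("execute") forwards to the
                               # thread-local Session in the registry
assert DBSession.is_active     # makeprop("is_active") property proxy
DBSession.remove()             # discard the current thread's session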
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py
index b1993118d..a3edacc19 100644
--- a/lib/sqlalchemy/orm/session.py
+++ b/lib/sqlalchemy/orm/session.py
@@ -10,15 +10,17 @@
import weakref
from .. import util, sql, engine, exc as sa_exc
from ..sql import util as sql_util, expression
-from . import (
- SessionExtension, attributes, exc, query,
- loading, identity
-)
+from . import SessionExtension, attributes, exc, query, loading, identity
from ..inspection import inspect
from .base import (
- object_mapper, class_mapper,
- _class_to_mapper, _state_mapper, object_state,
- _none_set, state_str, instance_str
+ object_mapper,
+ class_mapper,
+ _class_to_mapper,
+ _state_mapper,
+ object_state,
+ _none_set,
+ state_str,
+ instance_str,
)
import itertools
from . import persistence
@@ -26,8 +28,7 @@ from .unitofwork import UOWTransaction
from . import state as statelib
import sys
-__all__ = ['Session', 'SessionTransaction',
- 'SessionExtension', 'sessionmaker']
+__all__ = ["Session", "SessionTransaction", "SessionExtension", "sessionmaker"]
_sessions = weakref.WeakValueDictionary()
"""Weak-referencing dictionary of :class:`.Session` objects.
@@ -77,11 +78,11 @@ class _SessionClassMethods(object):
return object_session(instance)
-ACTIVE = util.symbol('ACTIVE')
-PREPARED = util.symbol('PREPARED')
-COMMITTED = util.symbol('COMMITTED')
-DEACTIVE = util.symbol('DEACTIVE')
-CLOSED = util.symbol('CLOSED')
+ACTIVE = util.symbol("ACTIVE")
+PREPARED = util.symbol("PREPARED")
+COMMITTED = util.symbol("COMMITTED")
+DEACTIVE = util.symbol("DEACTIVE")
+CLOSED = util.symbol("CLOSED")
class SessionTransaction(object):
@@ -212,7 +213,8 @@ class SessionTransaction(object):
if not parent and nested:
raise sa_exc.InvalidRequestError(
"Can't start a SAVEPOINT transaction when no existing "
- "transaction is in progress")
+ "transaction is in progress"
+ )
if self.session._enable_transaction_accounting:
self._take_snapshot()
@@ -249,10 +251,13 @@ class SessionTransaction(object):
def is_active(self):
return self.session is not None and self._state is ACTIVE
- def _assert_active(self, prepared_ok=False,
- rollback_ok=False,
- deactive_ok=False,
- closed_msg="This transaction is closed"):
+ def _assert_active(
+ self,
+ prepared_ok=False,
+ rollback_ok=False,
+ deactive_ok=False,
+ closed_msg="This transaction is closed",
+ ):
if self._state is COMMITTED:
raise sa_exc.InvalidRequestError(
"This session is in 'committed' state; no further "
@@ -295,21 +300,21 @@ class SessionTransaction(object):
def _begin(self, nested=False):
self._assert_active()
- return SessionTransaction(
- self.session, self, nested=nested)
+ return SessionTransaction(self.session, self, nested=nested)
def _iterate_self_and_parents(self, upto=None):
current = self
result = ()
while current:
- result += (current, )
+ result += (current,)
if current._parent is upto:
break
elif current._parent is None:
raise sa_exc.InvalidRequestError(
- "Transaction %s is not on the active transaction list" % (
- upto))
+ "Transaction %s is not on the active transaction list"
+ % (upto)
+ )
else:
current = current._parent
@@ -376,7 +381,8 @@ class SessionTransaction(object):
s._expire(s.dict, self.session.identity_map._modified)
statelib.InstanceState._detach_states(
- list(self._deleted), self.session)
+ list(self._deleted), self.session
+ )
self._deleted.clear()
elif self.nested:
self._parent._new.update(self._new)
@@ -391,7 +397,8 @@ class SessionTransaction(object):
if execution_options:
util.warn(
"Connection is already established for the "
- "given bind; execution_options ignored")
+ "given bind; execution_options ignored"
+ )
return self._connections[bind][0]
if self._parent:
@@ -404,7 +411,8 @@ class SessionTransaction(object):
if conn.engine in self._connections:
raise sa_exc.InvalidRequestError(
"Session already has a Connection associated for the "
- "given Connection's Engine")
+ "given Connection's Engine"
+ )
else:
conn = bind.contextual_connect()
@@ -418,8 +426,11 @@ class SessionTransaction(object):
else:
transaction = conn.begin()
- self._connections[conn] = self._connections[conn.engine] = \
- (conn, transaction, conn is not bind)
+ self._connections[conn] = self._connections[conn.engine] = (
+ conn,
+ transaction,
+ conn is not bind,
+ )
self.session.dispatch.after_begin(self.session, self, conn)
return conn
@@ -427,7 +438,8 @@ class SessionTransaction(object):
if self._parent is not None or not self.session.twophase:
raise sa_exc.InvalidRequestError(
"'twophase' mode not enabled, or not root transaction; "
- "can't prepare.")
+ "can't prepare."
+ )
self._prepare_impl()
def _prepare_impl(self):
@@ -449,7 +461,8 @@ class SessionTransaction(object):
raise exc.FlushError(
"Over 100 subsequent flushes have occurred within "
"session.commit() - is an after_flush() hook "
- "creating new objects?")
+ "creating new objects?"
+ )
if self._parent is None and self.session.twophase:
try:
@@ -504,7 +517,8 @@ class SessionTransaction(object):
transaction._state = DEACTIVE
if self.session._enable_transaction_accounting:
transaction._restore_snapshot(
- dirty_only=transaction.nested)
+ dirty_only=transaction.nested
+ )
boundary = transaction
break
else:
@@ -512,15 +526,19 @@ class SessionTransaction(object):
sess = self.session
- if not rollback_err and sess._enable_transaction_accounting and \
- not sess._is_clean():
+ if (
+ not rollback_err
+ and sess._enable_transaction_accounting
+ and not sess._is_clean()
+ ):
# if items were added, deleted, or mutated
# here, we need to re-restore the snapshot
util.warn(
"Session's state has been changed on "
"a non-active transaction - this state "
- "will be discarded.")
+ "will be discarded."
+ )
boundary._restore_snapshot(dirty_only=boundary.nested)
self.close()
@@ -535,12 +553,12 @@ class SessionTransaction(object):
return self._parent
-
def close(self, invalidate=False):
self.session.transaction = self._parent
if self._parent is None:
- for connection, transaction, autoclose in \
- set(self._connections.values()):
+ for connection, transaction, autoclose in set(
+ self._connections.values()
+ ):
if invalidate:
connection.invalidate()
if autoclose:
@@ -583,21 +601,49 @@ class Session(_SessionClassMethods):
"""
public_methods = (
- '__contains__', '__iter__', 'add', 'add_all', 'begin', 'begin_nested',
- 'close', 'commit', 'connection', 'delete', 'execute', 'expire',
- 'expire_all', 'expunge', 'expunge_all', 'flush', 'get_bind',
- 'is_modified', 'bulk_save_objects', 'bulk_insert_mappings',
- 'bulk_update_mappings',
- 'merge', 'query', 'refresh', 'rollback',
- 'scalar')
-
- def __init__(self, bind=None, autoflush=True, expire_on_commit=True,
- _enable_transaction_accounting=True,
- autocommit=False, twophase=False,
- weak_identity_map=True, binds=None, extension=None,
- enable_baked_queries=True,
- info=None,
- query_cls=query.Query):
+ "__contains__",
+ "__iter__",
+ "add",
+ "add_all",
+ "begin",
+ "begin_nested",
+ "close",
+ "commit",
+ "connection",
+ "delete",
+ "execute",
+ "expire",
+ "expire_all",
+ "expunge",
+ "expunge_all",
+ "flush",
+ "get_bind",
+ "is_modified",
+ "bulk_save_objects",
+ "bulk_insert_mappings",
+ "bulk_update_mappings",
+ "merge",
+ "query",
+ "refresh",
+ "rollback",
+ "scalar",
+ )
+
+ def __init__(
+ self,
+ bind=None,
+ autoflush=True,
+ expire_on_commit=True,
+ _enable_transaction_accounting=True,
+ autocommit=False,
+ twophase=False,
+ weak_identity_map=True,
+ binds=None,
+ extension=None,
+ enable_baked_queries=True,
+ info=None,
+ query_cls=query.Query,
+ ):
r"""Construct a new Session.
See also the :class:`.sessionmaker` function which is used to
@@ -753,12 +799,13 @@ class Session(_SessionClassMethods):
"weak_identity_map=False is deprecated. "
"See the documentation on 'Session Referencing Behavior' "
"for an event-based approach to maintaining strong identity "
- "references.")
+ "references."
+ )
self._identity_cls = identity.StrongInstanceDict
self.identity_map = self._identity_cls()
- self._new = {} # InstanceState->object, strong refs object
+ self._new = {} # InstanceState->object, strong refs object
self._deleted = {} # same
self.bind = bind
self.__binds = {}
@@ -861,15 +908,14 @@ class Session(_SessionClassMethods):
"""
if self.transaction is not None:
if subtransactions or nested:
- self.transaction = self.transaction._begin(
- nested=nested)
+ self.transaction = self.transaction._begin(nested=nested)
else:
raise sa_exc.InvalidRequestError(
"A transaction is already begun. Use "
- "subtransactions=True to allow subtransactions.")
+ "subtransactions=True to allow subtransactions."
+ )
else:
- self.transaction = SessionTransaction(
- self, nested=nested)
+ self.transaction = SessionTransaction(self, nested=nested)
return self.transaction # needed for __enter__/__exit__ hook
def begin_nested(self):
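The trailing comment above refers to using the returned SessionTransaction as a context manager; a sketch under an illustrative Widget model (not part of this changeset; the sketches further below reuse these names):

from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session

Base = declarative_base()

class Widget(Base):
    __tablename__ = "widget"  # hypothetical model, for illustration
    id = Column(Integer, primary_key=True)
    name = Column(String)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
session = Session(bind=engine)

# begin() on an already-begun Session raises InvalidRequestError unless
# subtransactions=True is passed, per the hunk above; begin_nested()
# instead emits a SAVEPOINT, and the SessionTransaction it returns works
# as a context manager via the __enter__/__exit__ hook noted above.
with session.begin_nested():
    session.add(Widget(name="a"))
session.commit()  # commits the enclosing transaction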
@@ -972,11 +1018,15 @@ class Session(_SessionClassMethods):
self.transaction.prepare()
- def connection(self, mapper=None, clause=None,
- bind=None,
- close_with_result=False,
- execution_options=None,
- **kw):
+ def connection(
+ self,
+ mapper=None,
+ clause=None,
+ bind=None,
+ close_with_result=False,
+ execution_options=None,
+ **kw
+ ):
r"""Return a :class:`.Connection` object corresponding to this
:class:`.Session` object's transactional state.
@@ -1041,14 +1091,17 @@ class Session(_SessionClassMethods):
if bind is None:
bind = self.get_bind(mapper, clause=clause, **kw)
- return self._connection_for_bind(bind,
- close_with_result=close_with_result,
- execution_options=execution_options)
+ return self._connection_for_bind(
+ bind,
+ close_with_result=close_with_result,
+ execution_options=execution_options,
+ )
def _connection_for_bind(self, engine, execution_options=None, **kw):
if self.transaction is not None:
return self.transaction._connection_for_bind(
- engine, execution_options)
+ engine, execution_options
+ )
else:
conn = engine.contextual_connect(**kw)
if execution_options:
@@ -1183,14 +1236,16 @@ class Session(_SessionClassMethods):
if bind is None:
bind = self.get_bind(mapper, clause=clause, **kw)
- return self._connection_for_bind(
- bind, close_with_result=True).execute(clause, params or {})
+ return self._connection_for_bind(bind, close_with_result=True).execute(
+ clause, params or {}
+ )
def scalar(self, clause, params=None, mapper=None, bind=None, **kw):
"""Like :meth:`~.Session.execute` but return a scalar result."""
return self.execute(
- clause, params=params, mapper=mapper, bind=bind, **kw).scalar()
+ clause, params=params, mapper=mapper, bind=bind, **kw
+ ).scalar()
def close(self):
"""Close this Session.
@@ -1256,9 +1311,7 @@ class Session(_SessionClassMethods):
self._new = {}
self._deleted = {}
- statelib.InstanceState._detach_states(
- all_states, self
- )
+ statelib.InstanceState._detach_states(all_states, self)
def _add_bind(self, key, bind):
try:
@@ -1266,7 +1319,8 @@ class Session(_SessionClassMethods):
except sa_exc.NoInspectionAvailable:
if not isinstance(key, type):
raise sa_exc.ArgumentError(
- "Not an acceptable bind target: %s" % key)
+ "Not an acceptable bind target: %s" % key
+ )
else:
self.__binds[key] = bind
else:
@@ -1278,7 +1332,8 @@ class Session(_SessionClassMethods):
self.__binds[selectable] = bind
else:
raise sa_exc.ArgumentError(
- "Not an acceptable bind target: %s" % key)
+ "Not an acceptable bind target: %s" % key
+ )
def bind_mapper(self, mapper, bind):
"""Associate a :class:`.Mapper` or arbitrary Python class with a
@@ -1408,7 +1463,8 @@ class Session(_SessionClassMethods):
raise sa_exc.UnboundExecutionError(
"This session is not bound to a single Engine or "
"Connection, and no context was provided to locate "
- "a binding.")
+ "a binding."
+ )
if mapper is not None:
try:
@@ -1443,13 +1499,14 @@ class Session(_SessionClassMethods):
context = []
if mapper is not None:
- context.append('mapper %s' % mapper)
+ context.append("mapper %s" % mapper)
if clause is not None:
- context.append('SQL expression')
+ context.append("SQL expression")
raise sa_exc.UnboundExecutionError(
- "Could not locate a bind configured on %s or this Session" % (
- ', '.join(context)))
+ "Could not locate a bind configured on %s or this Session"
+ % (", ".join(context))
+ )
def query(self, *entities, **kwargs):
"""Return a new :class:`.Query` object corresponding to this
@@ -1499,12 +1556,17 @@ class Session(_SessionClassMethods):
e.add_detail(
"raised as a result of Query-invoked autoflush; "
"consider using a session.no_autoflush block if this "
- "flush is occurring prematurely")
+ "flush is occurring prematurely"
+ )
util.raise_from_cause(e)
def refresh(
- self, instance, attribute_names=None, with_for_update=None,
- lockmode=None):
+ self,
+ instance,
+ attribute_names=None,
+ with_for_update=None,
+ lockmode=None,
+ ):
"""Expire and refresh the attributes on the given instance.
A query will be issued to the database and all attributes will be
@@ -1560,7 +1622,8 @@ class Session(_SessionClassMethods):
raise sa_exc.ArgumentError(
"with_for_update should be the boolean value "
"True, or a dictionary with options. "
- "A blank dictionary is ambiguous.")
+ "A blank dictionary is ambiguous."
+ )
if lockmode:
with_for_update = query.LockmodeArg.parse_legacy_query(lockmode)
@@ -1572,14 +1635,19 @@ class Session(_SessionClassMethods):
else:
with_for_update = None
- if loading.load_on_ident(
+ if (
+ loading.load_on_ident(
self.query(object_mapper(instance)),
- state.key, refresh_state=state,
+ state.key,
+ refresh_state=state,
with_for_update=with_for_update,
- only_load_props=attribute_names) is None:
+ only_load_props=attribute_names,
+ )
+ is None
+ ):
raise sa_exc.InvalidRequestError(
- "Could not refresh instance '%s'" %
- instance_str(instance))
+ "Could not refresh instance '%s'" % instance_str(instance)
+ )
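Continuing the Widget sketch above, the forms of with_for_update accepted by the argument check in this hunk (an empty dict is rejected as ambiguous):

w = Widget(name="b")
session.add(w)
session.commit()

session.refresh(w)                        # plain re-SELECT of the row
session.refresh(w, with_for_update=True)  # SELECT ... FOR UPDATE
session.refresh(w, with_for_update={"read": True})  # read-lock variant
# (SQLite ignores FOR UPDATE clauses, so this sketch runs there too.)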
def expire_all(self):
"""Expires all persistent instances within this Session.
@@ -1662,8 +1730,9 @@ class Session(_SessionClassMethods):
else:
# pre-fetch the full cascade since the expire is going to
# remove associations
- cascaded = list(state.manager.mapper.cascade_iterator(
- 'refresh-expire', state))
+ cascaded = list(
+ state.manager.mapper.cascade_iterator("refresh-expire", state)
+ )
self._conditional_expire(state)
for o, m, st_, dct_ in cascaded:
self._conditional_expire(st_)
@@ -1677,8 +1746,11 @@ class Session(_SessionClassMethods):
self._new.pop(state)
state._detach(self)
- @util.deprecated("0.7", "The non-weak-referencing identity map "
- "feature is no longer needed.")
+ @util.deprecated(
+ "0.7",
+ "The non-weak-referencing identity map "
+ "feature is no longer needed.",
+ )
def prune(self):
"""Remove unreferenced instances cached in the identity map.
@@ -1705,14 +1777,13 @@ class Session(_SessionClassMethods):
raise exc.UnmappedInstanceError(instance)
if state.session_id is not self.hash_key:
raise sa_exc.InvalidRequestError(
- "Instance %s is not present in this Session" %
- state_str(state))
+ "Instance %s is not present in this Session" % state_str(state)
+ )
- cascaded = list(state.manager.mapper.cascade_iterator(
- 'expunge', state))
- self._expunge_states(
- [state] + [st_ for o, m, st_, dct_ in cascaded]
+ cascaded = list(
+ state.manager.mapper.cascade_iterator("expunge", state)
)
+ self._expunge_states([state] + [st_ for o, m, st_, dct_ in cascaded])
def _expunge_states(self, states, to_transient=False):
for state in states:
@@ -1726,7 +1797,8 @@ class Session(_SessionClassMethods):
# in the transaction snapshot
self.transaction._deleted.pop(state, None)
statelib.InstanceState._detach_states(
- states, self, to_transient=to_transient)
+ states, self, to_transient=to_transient
+ )
def _register_newly_persistent(self, states):
pending_to_persistent = self.dispatch.pending_to_persistent or None
@@ -1739,9 +1811,11 @@ class Session(_SessionClassMethods):
instance_key = mapper._identity_key_from_state(state)
- if _none_set.intersection(instance_key[1]) and \
- not mapper.allow_partial_pks or \
- _none_set.issuperset(instance_key[1]):
+ if (
+ _none_set.intersection(instance_key[1])
+ and not mapper.allow_partial_pks
+ or _none_set.issuperset(instance_key[1])
+ ):
raise exc.FlushError(
"Instance %s has a NULL identity key. If this is an "
"auto-generated value, check that the database table "
@@ -1765,15 +1839,16 @@ class Session(_SessionClassMethods):
else:
orig_key = state.key
self.transaction._key_switches[state] = (
- orig_key, instance_key)
+ orig_key,
+ instance_key,
+ )
state.key = instance_key
self.identity_map.replace(state)
state._orphaned_outside_of_session = False
statelib.InstanceState._commit_all_states(
- ((state, state.dict) for state in states),
- self.identity_map
+ ((state, state.dict) for state in states), self.identity_map
)
self._register_altered(states)
@@ -1849,9 +1924,8 @@ class Session(_SessionClassMethods):
mapper = _state_mapper(state)
for o, m, st_, dct_ in mapper.cascade_iterator(
- 'save-update',
- state,
- halt_on=self._contains_state):
+ "save-update", state, halt_on=self._contains_state
+ ):
self._save_or_update_impl(st_)
def delete(self, instance):
@@ -1875,8 +1949,8 @@ class Session(_SessionClassMethods):
if state.key is None:
if head:
raise sa_exc.InvalidRequestError(
- "Instance '%s' is not persisted" %
- state_str(state))
+ "Instance '%s' is not persisted" % state_str(state)
+ )
else:
return
@@ -1894,8 +1968,9 @@ class Session(_SessionClassMethods):
# grab the cascades before adding the item to the deleted list
# so that autoflush does not delete the item
# the strong reference to the instance itself is significant here
- cascade_states = list(state.manager.mapper.cascade_iterator(
- 'delete', state))
+ cascade_states = list(
+ state.manager.mapper.cascade_iterator("delete", state)
+ )
self._deleted[state] = obj
@@ -1975,13 +2050,21 @@ class Session(_SessionClassMethods):
return self._merge(
attributes.instance_state(instance),
attributes.instance_dict(instance),
- load=load, _recursive=_recursive,
- _resolve_conflict_map=_resolve_conflict_map)
+ load=load,
+ _recursive=_recursive,
+ _resolve_conflict_map=_resolve_conflict_map,
+ )
finally:
self.autoflush = autoflush
- def _merge(self, state, state_dict, load=True, _recursive=None,
- _resolve_conflict_map=None):
+ def _merge(
+ self,
+ state,
+ state_dict,
+ load=True,
+ _recursive=None,
+ _resolve_conflict_map=None,
+ ):
mapper = _state_mapper(state)
if state in _recursive:
return _recursive[state]
@@ -1995,11 +2078,15 @@ class Session(_SessionClassMethods):
"merge() with load=False option does not support "
"objects transient (i.e. unpersisted) objects. flush() "
"all changes on mapped instances before merging with "
- "load=False.")
+ "load=False."
+ )
key = mapper._identity_key_from_state(state)
key_is_persistent = attributes.NEVER_SET not in key[1] and (
- not _none_set.intersection(key[1]) or
- (mapper.allow_partial_pks and not _none_set.issuperset(key[1]))
+ not _none_set.intersection(key[1])
+ or (
+ mapper.allow_partial_pks
+ and not _none_set.issuperset(key[1])
+ )
)
else:
key_is_persistent = True
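Continuing the same sketch, the persistent-key path above in action; merge() reconciles a detached stand-in against the identity map:

detached = Widget()  # transient stand-in carrying only a primary key
detached.id = w.id
merged = session.merge(detached)  # key_is_persistent: loads the row
assert merged is w                # returns the identity-mapped instance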
@@ -2022,7 +2109,8 @@ class Session(_SessionClassMethods):
raise sa_exc.InvalidRequestError(
"merge() with load=False option does not support "
"objects marked as 'dirty'. flush() all changes on "
- "mapped instances before merging with load=False.")
+ "mapped instances before merging with load=False."
+ )
merged = mapper.class_manager.new_instance()
merged_state = attributes.instance_state(merged)
merged_state.key = key
@@ -2054,17 +2142,21 @@ class Session(_SessionClassMethods):
state,
state_dict,
mapper.version_id_col,
- passive=attributes.PASSIVE_NO_INITIALIZE)
+ passive=attributes.PASSIVE_NO_INITIALIZE,
+ )
merged_version = mapper._get_state_attr_by_column(
merged_state,
merged_dict,
mapper.version_id_col,
- passive=attributes.PASSIVE_NO_INITIALIZE)
+ passive=attributes.PASSIVE_NO_INITIALIZE,
+ )
- if existing_version is not attributes.PASSIVE_NO_RESULT and \
- merged_version is not attributes.PASSIVE_NO_RESULT and \
- existing_version != merged_version:
+ if (
+ existing_version is not attributes.PASSIVE_NO_RESULT
+ and merged_version is not attributes.PASSIVE_NO_RESULT
+ and existing_version != merged_version
+ ):
raise exc.StaleDataError(
"Version id '%s' on merged state %s "
"does not match existing version '%s'. "
@@ -2073,8 +2165,9 @@ class Session(_SessionClassMethods):
% (
existing_version,
state_str(merged_state),
- merged_version
- ))
+ merged_version,
+ )
+ )
merged_state.load_path = state.load_path
merged_state.load_options = state.load_options
@@ -2087,9 +2180,16 @@ class Session(_SessionClassMethods):
merged_state._copy_callables(state)
for prop in mapper.iterate_properties:
- prop.merge(self, state, state_dict,
- merged_state, merged_dict,
- load, _recursive, _resolve_conflict_map)
+ prop.merge(
+ self,
+ state,
+ state_dict,
+ merged_state,
+ merged_dict,
+ load,
+ _recursive,
+ _resolve_conflict_map,
+ )
if not load:
# remove any history
@@ -2102,14 +2202,16 @@ class Session(_SessionClassMethods):
def _validate_persistent(self, state):
if not self.identity_map.contains_state(state):
raise sa_exc.InvalidRequestError(
- "Instance '%s' is not persistent within this Session" %
- state_str(state))
+ "Instance '%s' is not persistent within this Session"
+ % state_str(state)
+ )
def _save_impl(self, state):
if state.key is not None:
raise sa_exc.InvalidRequestError(
"Object '%s' already has an identity - "
- "it can't be registered as pending" % state_str(state))
+ "it can't be registered as pending" % state_str(state)
+ )
obj = state.obj()
to_attach = self._before_attach(state, obj)
@@ -2122,8 +2224,8 @@ class Session(_SessionClassMethods):
def _update_impl(self, state, revert_deletion=False):
if state.key is None:
raise sa_exc.InvalidRequestError(
- "Instance '%s' is not persisted" %
- state_str(state))
+ "Instance '%s' is not persisted" % state_str(state)
+ )
if state._deleted:
if revert_deletion:
@@ -2135,8 +2237,7 @@ class Session(_SessionClassMethods):
"Instance '%s' has been deleted. "
"Use the make_transient() "
"function to send this object back "
- "to the transient state." %
- state_str(state)
+ "to the transient state." % state_str(state)
)
obj = state.obj()
@@ -2234,8 +2335,9 @@ class Session(_SessionClassMethods):
if state.session_id and state.session_id in _sessions:
raise sa_exc.InvalidRequestError(
"Object '%s' is already attached to session '%s' "
- "(this is '%s')" % (state_str(state),
- state.session_id, self.hash_key))
+ "(this is '%s')"
+ % (state_str(state), state.session_id, self.hash_key)
+ )
self.dispatch.before_attach(self, obj)
@@ -2271,7 +2373,8 @@ class Session(_SessionClassMethods):
"""
return iter(
- list(self._new.values()) + list(self.identity_map.values()))
+ list(self._new.values()) + list(self.identity_map.values())
+ )
def _contains_state(self, state):
return state in self._new or self.identity_map.contains_state(state)
@@ -2319,13 +2422,15 @@ class Session(_SessionClassMethods):
"Usage of the '%s' operation is not currently supported "
"within the execution stage of the flush process. "
"Results may not be consistent. Consider using alternative "
- "event listeners or connection-level operations instead."
- % method)
+ "event listeners or connection-level operations instead." % method
+ )
def _is_clean(self):
- return not self.identity_map.check_modified() and \
- not self._deleted and \
- not self._new
+ return (
+ not self.identity_map.check_modified()
+ and not self._deleted
+ and not self._new
+ )
def _flush(self, objects=None):
@@ -2375,12 +2480,16 @@ class Session(_SessionClassMethods):
is_persistent_orphan = is_orphan and state.has_identity
- if is_orphan and not is_persistent_orphan and \
- state._orphaned_outside_of_session:
+ if (
+ is_orphan
+ and not is_persistent_orphan
+ and state._orphaned_outside_of_session
+ ):
self._expunge_states([state])
else:
_reg = flush_context.register_object(
- state, isdelete=is_persistent_orphan)
+ state, isdelete=is_persistent_orphan
+ )
assert _reg, "Failed to add object to the flush context!"
processed.add(state)
@@ -2397,7 +2506,8 @@ class Session(_SessionClassMethods):
return
flush_context.transaction = transaction = self.begin(
- subtransactions=True)
+ subtransactions=True
+ )
try:
self._warn_on_events = True
try:
@@ -2413,16 +2523,20 @@ class Session(_SessionClassMethods):
len_ = len(self.identity_map._modified)
statelib.InstanceState._commit_all_states(
- [(state, state.dict) for state in
- self.identity_map._modified],
- instance_dict=self.identity_map)
- util.warn("Attribute history events accumulated on %d "
- "previously clean instances "
- "within inner-flush event handlers have been "
- "reset, and will not result in database updates. "
- "Consider using set_committed_value() within "
- "inner-flush event handlers to avoid this warning."
- % len_)
+ [
+ (state, state.dict)
+ for state in self.identity_map._modified
+ ],
+ instance_dict=self.identity_map,
+ )
+ util.warn(
+ "Attribute history events accumulated on %d "
+ "previously clean instances "
+ "within inner-flush event handlers have been "
+ "reset, and will not result in database updates. "
+ "Consider using set_committed_value() within "
+ "inner-flush event handlers to avoid this warning." % len_
+ )
# useful assertions:
# if not objects:
@@ -2440,8 +2554,12 @@ class Session(_SessionClassMethods):
transaction.rollback(_capture_exception=True)
def bulk_save_objects(
- self, objects, return_defaults=False, update_changed_only=True,
- preserve_order=True):
+ self,
+ objects,
+ return_defaults=False,
+ update_changed_only=True,
+ preserve_order=True,
+ ):
"""Perform a bulk save of the given list of objects.
The bulk save feature allows mapped objects to be used as the
@@ -2520,6 +2638,7 @@ class Session(_SessionClassMethods):
:meth:`.Session.bulk_update_mappings`
"""
+
def key(state):
return (state.mapper, state.key is not None)
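Continuing the Widget sketch, what the key() helper above feeds to itertools.groupby; contiguous objects sharing a mapper and an insert-vs-update disposition become one bulk batch:

session.bulk_save_objects([Widget(name="c"), Widget(name="d")])  # 1 batch
session.bulk_insert_mappings(Widget, [{"id": 100, "name": "e"}])
session.bulk_update_mappings(Widget, [{"id": 100, "name": "f"}])
session.commit()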
@@ -2527,15 +2646,20 @@ class Session(_SessionClassMethods):
if not preserve_order:
obj_states = sorted(obj_states, key=key)
- for (mapper, isupdate), states in itertools.groupby(
- obj_states, key
- ):
+ for (mapper, isupdate), states in itertools.groupby(obj_states, key):
self._bulk_save_mappings(
- mapper, states, isupdate, True,
- return_defaults, update_changed_only, False)
+ mapper,
+ states,
+ isupdate,
+ True,
+ return_defaults,
+ update_changed_only,
+ False,
+ )
def bulk_insert_mappings(
- self, mapper, mappings, return_defaults=False, render_nulls=False):
+ self, mapper, mappings, return_defaults=False, render_nulls=False
+ ):
"""Perform a bulk insert of the given list of mapping dictionaries.
The bulk insert feature allows plain Python dictionaries to be used as
@@ -2622,8 +2746,14 @@ class Session(_SessionClassMethods):
"""
self._bulk_save_mappings(
- mapper, mappings, False, False,
- return_defaults, False, render_nulls)
+ mapper,
+ mappings,
+ False,
+ False,
+ return_defaults,
+ False,
+ render_nulls,
+ )
def bulk_update_mappings(self, mapper, mappings):
"""Perform a bulk update of the given list of mapping dictionaries.
@@ -2673,25 +2803,41 @@ class Session(_SessionClassMethods):
"""
self._bulk_save_mappings(
- mapper, mappings, True, False, False, False, False)
+ mapper, mappings, True, False, False, False, False
+ )
def _bulk_save_mappings(
- self, mapper, mappings, isupdate, isstates,
- return_defaults, update_changed_only, render_nulls):
+ self,
+ mapper,
+ mappings,
+ isupdate,
+ isstates,
+ return_defaults,
+ update_changed_only,
+ render_nulls,
+ ):
mapper = _class_to_mapper(mapper)
self._flushing = True
- transaction = self.begin(
- subtransactions=True)
+ transaction = self.begin(subtransactions=True)
try:
if isupdate:
persistence._bulk_update(
- mapper, mappings, transaction,
- isstates, update_changed_only)
+ mapper,
+ mappings,
+ transaction,
+ isstates,
+ update_changed_only,
+ )
else:
persistence._bulk_insert(
- mapper, mappings, transaction,
- isstates, return_defaults, render_nulls)
+ mapper,
+ mappings,
+ transaction,
+ isstates,
+ return_defaults,
+ render_nulls,
+ )
transaction.commit()
except:
@@ -2700,8 +2846,7 @@ class Session(_SessionClassMethods):
finally:
self._flushing = False
- def is_modified(self, instance, include_collections=True,
- passive=True):
+ def is_modified(self, instance, include_collections=True, passive=True):
r"""Return ``True`` if the given instance has locally
modified attributes.
@@ -2775,16 +2920,15 @@ class Session(_SessionClassMethods):
dict_ = state.dict
for attr in state.manager.attributes:
- if \
- (
- not include_collections and
- hasattr(attr.impl, 'get_collection')
- ) or not hasattr(attr.impl, 'get_history'):
+ if (
+ not include_collections
+ and hasattr(attr.impl, "get_collection")
+ ) or not hasattr(attr.impl, "get_history"):
continue
- (added, unchanged, deleted) = \
- attr.impl.get_history(state, dict_,
- passive=attributes.NO_CHANGE)
+ (added, unchanged, deleted) = attr.impl.get_history(
+ state, dict_, passive=attributes.NO_CHANGE
+ )
if added or deleted:
return True
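Continuing the Widget sketch, the history scan above as seen from the public API; is_modified() is a net-change test against committed state, not a plain dirty flag:

w.name = "changed"
assert session.is_modified(w)
assert w in session.dirty
session.commit()
assert not session.is_modified(w)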
@@ -2898,9 +3042,12 @@ class Session(_SessionClassMethods):
"""
return util.IdentitySet(
- [state.obj()
- for state in self._dirty_states
- if state not in self._deleted])
+ [
+ state.obj()
+ for state in self._dirty_states
+ if state not in self._deleted
+ ]
+ )
@property
def deleted(self):
@@ -2961,10 +3108,16 @@ class sessionmaker(_SessionClassMethods):
"""
- def __init__(self, bind=None, class_=Session, autoflush=True,
- autocommit=False,
- expire_on_commit=True,
- info=None, **kw):
+ def __init__(
+ self,
+ bind=None,
+ class_=Session,
+ autoflush=True,
+ autocommit=False,
+ expire_on_commit=True,
+ info=None,
+ **kw
+ ):
r"""Construct a new :class:`.sessionmaker`.
All arguments here except for ``class_`` correspond to arguments
@@ -2992,12 +3145,12 @@ class sessionmaker(_SessionClassMethods):
constructor of newly created :class:`.Session` objects.
"""
- kw['bind'] = bind
- kw['autoflush'] = autoflush
- kw['autocommit'] = autocommit
- kw['expire_on_commit'] = expire_on_commit
+ kw["bind"] = bind
+ kw["autoflush"] = autoflush
+ kw["autocommit"] = autocommit
+ kw["expire_on_commit"] = expire_on_commit
if info is not None:
- kw['info'] = info
+ kw["info"] = info
self.kw = kw
# make our own subclass of the given class, so that
# events can be associated with it specifically.
@@ -3015,10 +3168,10 @@ class sessionmaker(_SessionClassMethods):
"""
for k, v in self.kw.items():
- if k == 'info' and 'info' in local_kw:
+ if k == "info" and "info" in local_kw:
d = v.copy()
- d.update(local_kw['info'])
- local_kw['info'] = d
+ d.update(local_kw["info"])
+ local_kw["info"] = d
else:
local_kw.setdefault(k, v)
return self.class_(**local_kw)
@@ -3038,7 +3191,7 @@ class sessionmaker(_SessionClassMethods):
return "%s(class_=%r, %s)" % (
self.__class__.__name__,
self.class_.__name__,
- ", ".join("%s=%r" % (k, v) for k, v in self.kw.items())
+ ", ".join("%s=%r" % (k, v) for k, v in self.kw.items()),
)
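A sketch of the keyword handling above, reusing the engine from the Widget sketch; per-call keywords override the factory defaults except "info", which is merged:

from sqlalchemy.orm import sessionmaker

factory = sessionmaker(bind=engine, info={"app": "demo"})
s = factory(info={"request_id": 42}, autoflush=False)
assert s.info == {"app": "demo", "request_id": 42}
factory.configure(expire_on_commit=False)  # affects sessions made later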
@@ -3139,8 +3292,7 @@ def make_transient_to_detached(instance):
"""
state = attributes.instance_state(instance)
if state.session_id or state.key:
- raise sa_exc.InvalidRequestError(
- "Given object must be transient")
+ raise sa_exc.InvalidRequestError("Given object must be transient")
state.key = state.mapper._identity_key_from_state(state)
if state._deleted:
del state._deleted
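The helper above in use, continuing the Widget sketch; the key value 99 is an arbitrary illustration:

from sqlalchemy.orm import make_transient_to_detached

stub = Widget()
stub.id = 99  # the primary key must be populated before the call
make_transient_to_detached(stub)  # stamps the identity key, no SQL
session2 = Session(bind=engine)
session2.add(stub)  # re-attached as persistent; no INSERT is emitted,
                    # and expired attributes load from the row on access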
diff --git a/lib/sqlalchemy/orm/state.py b/lib/sqlalchemy/orm/state.py
index 944dc8177..c36d8817b 100644
--- a/lib/sqlalchemy/orm/state.py
+++ b/lib/sqlalchemy/orm/state.py
@@ -18,8 +18,16 @@ from .. import inspection
from .. import exc as sa_exc
from . import exc as orm_exc, interfaces
from .path_registry import PathRegistry
-from .base import PASSIVE_NO_RESULT, SQL_OK, NEVER_SET, ATTR_WAS_SET, \
- NO_VALUE, PASSIVE_NO_INITIALIZE, INIT_OK, PASSIVE_OFF
+from .base import (
+ PASSIVE_NO_RESULT,
+ SQL_OK,
+ NEVER_SET,
+ ATTR_WAS_SET,
+ NO_VALUE,
+ PASSIVE_NO_INITIALIZE,
+ INIT_OK,
+ PASSIVE_OFF,
+)
from . import base
@@ -106,10 +114,7 @@ class InstanceState(interfaces.InspectionAttrInfo):
"""
return util.ImmutableProperties(
- dict(
- (key, AttributeState(self, key))
- for key in self.manager
- )
+ dict((key, AttributeState(self, key)) for key in self.manager)
)
@property
@@ -121,8 +126,7 @@ class InstanceState(interfaces.InspectionAttrInfo):
:ref:`session_object_states`
"""
- return self.key is None and \
- not self._attached
+ return self.key is None and not self._attached
@property
def pending(self):
@@ -134,8 +138,7 @@ class InstanceState(interfaces.InspectionAttrInfo):
:ref:`session_object_states`
"""
- return self.key is None and \
- self._attached
+ return self.key is None and self._attached
@property
def deleted(self):
@@ -164,8 +167,7 @@ class InstanceState(interfaces.InspectionAttrInfo):
:ref:`session_object_states`
"""
- return self.key is not None and \
- self._attached and self._deleted
+ return self.key is not None and self._attached and self._deleted
@property
def was_deleted(self):
@@ -210,8 +212,7 @@ class InstanceState(interfaces.InspectionAttrInfo):
:ref:`session_object_states`
"""
- return self.key is not None and \
- self._attached and not self._deleted
+ return self.key is not None and self._attached and not self._deleted
@property
def detached(self):
@@ -227,8 +228,10 @@ class InstanceState(interfaces.InspectionAttrInfo):
@property
@util.dependencies("sqlalchemy.orm.session")
def _attached(self, sessionlib):
- return self.session_id is not None and \
- self.session_id in sessionlib._sessions
+ return (
+ self.session_id is not None
+ and self.session_id in sessionlib._sessions
+ )
def _track_last_known_value(self, key):
"""Track the last known value of a particular key after expiration
@@ -323,14 +326,14 @@ class InstanceState(interfaces.InspectionAttrInfo):
@classmethod
def _detach_states(self, states, session, to_transient=False):
- persistent_to_detached = \
+ persistent_to_detached = (
session.dispatch.persistent_to_detached or None
- deleted_to_detached = \
- session.dispatch.deleted_to_detached or None
- pending_to_transient = \
- session.dispatch.pending_to_transient or None
- persistent_to_transient = \
+ )
+ deleted_to_detached = session.dispatch.deleted_to_detached or None
+ pending_to_transient = session.dispatch.pending_to_transient or None
+ persistent_to_transient = (
session.dispatch.persistent_to_transient or None
+ )
for state in states:
deleted = state._deleted
@@ -448,23 +451,33 @@ class InstanceState(interfaces.InspectionAttrInfo):
return self._pending_mutations[key]
def __getstate__(self):
- state_dict = {'instance': self.obj()}
+ state_dict = {"instance": self.obj()}
state_dict.update(
- (k, self.__dict__[k]) for k in (
- 'committed_state', '_pending_mutations', 'modified',
- 'expired', 'callables', 'key', 'parents', 'load_options',
- 'class_', 'expired_attributes', 'info'
- ) if k in self.__dict__
+ (k, self.__dict__[k])
+ for k in (
+ "committed_state",
+ "_pending_mutations",
+ "modified",
+ "expired",
+ "callables",
+ "key",
+ "parents",
+ "load_options",
+ "class_",
+ "expired_attributes",
+ "info",
+ )
+ if k in self.__dict__
)
if self.load_path:
- state_dict['load_path'] = self.load_path.serialize()
+ state_dict["load_path"] = self.load_path.serialize()
- state_dict['manager'] = self.manager._serialize(self, state_dict)
+ state_dict["manager"] = self.manager._serialize(self, state_dict)
return state_dict
def __setstate__(self, state_dict):
- inst = state_dict['instance']
+ inst = state_dict["instance"]
if inst is not None:
self.obj = weakref.ref(inst, self._cleanup)
self.class_ = inst.__class__
@@ -473,20 +486,20 @@ class InstanceState(interfaces.InspectionAttrInfo):
# due to storage of state in "parents". "class_"
# also new.
self.obj = None
- self.class_ = state_dict['class_']
-
- self.committed_state = state_dict.get('committed_state', {})
- self._pending_mutations = state_dict.get('_pending_mutations', {})
- self.parents = state_dict.get('parents', {})
- self.modified = state_dict.get('modified', False)
- self.expired = state_dict.get('expired', False)
- if 'info' in state_dict:
- self.info.update(state_dict['info'])
- if 'callables' in state_dict:
- self.callables = state_dict['callables']
+ self.class_ = state_dict["class_"]
+
+ self.committed_state = state_dict.get("committed_state", {})
+ self._pending_mutations = state_dict.get("_pending_mutations", {})
+ self.parents = state_dict.get("parents", {})
+ self.modified = state_dict.get("modified", False)
+ self.expired = state_dict.get("expired", False)
+ if "info" in state_dict:
+ self.info.update(state_dict["info"])
+ if "callables" in state_dict:
+ self.callables = state_dict["callables"]
try:
- self.expired_attributes = state_dict['expired_attributes']
+ self.expired_attributes = state_dict["expired_attributes"]
except KeyError:
self.expired_attributes = set()
# 0.9 and earlier compat
@@ -495,30 +508,31 @@ class InstanceState(interfaces.InspectionAttrInfo):
self.expired_attributes.add(k)
del self.callables[k]
else:
- if 'expired_attributes' in state_dict:
- self.expired_attributes = state_dict['expired_attributes']
+ if "expired_attributes" in state_dict:
+ self.expired_attributes = state_dict["expired_attributes"]
else:
self.expired_attributes = set()
- self.__dict__.update([
- (k, state_dict[k]) for k in (
- 'key', 'load_options'
- ) if k in state_dict
- ])
+ self.__dict__.update(
+ [
+ (k, state_dict[k])
+ for k in ("key", "load_options")
+ if k in state_dict
+ ]
+ )
if self.key:
try:
self.identity_token = self.key[2]
except IndexError:
# 1.1 and earlier compat before identity_token
assert len(self.key) == 2
- self.key = self.key + (None, )
+ self.key = self.key + (None,)
self.identity_token = None
- if 'load_path' in state_dict:
- self.load_path = PathRegistry.\
- deserialize(state_dict['load_path'])
+ if "load_path" in state_dict:
+ self.load_path = PathRegistry.deserialize(state_dict["load_path"])
- state_dict['manager'](self, inst, state_dict)
+ state_dict["manager"](self, inst, state_dict)
def _reset(self, dict_, key):
"""Remove the given attribute and any
@@ -532,25 +546,29 @@ class InstanceState(interfaces.InspectionAttrInfo):
self.callables.pop(key, None)
def _copy_callables(self, from_):
- if 'callables' in from_.__dict__:
+ if "callables" in from_.__dict__:
self.callables = dict(from_.callables)
@classmethod
def _instance_level_callable_processor(cls, manager, fn, key):
impl = manager[key].impl
if impl.collection:
+
def _set_callable(state, dict_, row):
- if 'callables' not in state.__dict__:
+ if "callables" not in state.__dict__:
state.callables = {}
old = dict_.pop(key, None)
if old is not None:
impl._invalidate_collection(old)
state.callables[key] = fn
+
else:
+
def _set_callable(state, dict_, row):
- if 'callables' not in state.__dict__:
+ if "callables" not in state.__dict__:
state.callables = {}
state.callables[key] = fn
+
return _set_callable
def _expire(self, dict_, modified_set):
@@ -563,15 +581,18 @@ class InstanceState(interfaces.InspectionAttrInfo):
self._strong_obj = None
- if '_pending_mutations' in self.__dict__:
- del self.__dict__['_pending_mutations']
+ if "_pending_mutations" in self.__dict__:
+ del self.__dict__["_pending_mutations"]
- if 'parents' in self.__dict__:
- del self.__dict__['parents']
+ if "parents" in self.__dict__:
+ del self.__dict__["parents"]
self.expired_attributes.update(
- [impl.key for impl in self.manager._scalar_loader_impls
- if impl.expire_missing or impl.key in dict_]
+ [
+ impl.key
+ for impl in self.manager._scalar_loader_impls
+ if impl.expire_missing or impl.key in dict_
+ ]
)
if self.callables:
@@ -584,8 +605,7 @@ class InstanceState(interfaces.InspectionAttrInfo):
if self._last_known_values:
self._last_known_values.update(
- (k, dict_[k]) for k in self._last_known_values
- if k in dict_
+ (k, dict_[k]) for k in self._last_known_values if k in dict_
)
for key in self.manager._all_key_set.intersection(dict_):
@@ -594,17 +614,14 @@ class InstanceState(interfaces.InspectionAttrInfo):
self.manager.dispatch.expire(self, None)
def _expire_attributes(self, dict_, attribute_names, no_loader=False):
- pending = self.__dict__.get('_pending_mutations', None)
+ pending = self.__dict__.get("_pending_mutations", None)
callables = self.callables
for key in attribute_names:
impl = self.manager[key].impl
if impl.accepts_scalar_loader:
- if no_loader and (
- impl.callable_ or
- key in callables
- ):
+ if no_loader and (impl.callable_ or key in callables):
continue
self.expired_attributes.add(key)
@@ -614,8 +631,11 @@ class InstanceState(interfaces.InspectionAttrInfo):
if impl.collection and old is not NO_VALUE:
impl._invalidate_collection(old)
- if self._last_known_values and key in self._last_known_values \
- and old is not NO_VALUE:
+ if (
+ self._last_known_values
+ and key in self._last_known_values
+ and old is not NO_VALUE
+ ):
self._last_known_values[key] = old
self.committed_state.pop(key, None)
@@ -634,8 +654,7 @@ class InstanceState(interfaces.InspectionAttrInfo):
if not passive & SQL_OK:
return PASSIVE_NO_RESULT
- toload = self.expired_attributes.\
- intersection(self.unmodified)
+ toload = self.expired_attributes.intersection(self.unmodified)
self.manager.deferred_scalar_loader(self, toload)
@@ -656,9 +675,11 @@ class InstanceState(interfaces.InspectionAttrInfo):
def unmodified_intersection(self, keys):
"""Return self.unmodified.intersection(keys)."""
-
- return set(keys).intersection(self.manager).\
- difference(self.committed_state)
+ return (
+ set(keys)
+ .intersection(self.manager)
+ .difference(self.committed_state)
+ )
@property
def unloaded(self):
@@ -668,9 +689,11 @@ class InstanceState(interfaces.InspectionAttrInfo):
was never populated or modified.
"""
- return set(self.manager).\
- difference(self.committed_state).\
- difference(self.dict)
+ return (
+ set(self.manager)
+ .difference(self.committed_state)
+ .difference(self.dict)
+ )
@property
def unloaded_expirable(self):
@@ -681,13 +704,16 @@ class InstanceState(interfaces.InspectionAttrInfo):
"""
return self.unloaded.intersection(
- attr for attr in self.manager
- if self.manager[attr].impl.expire_missing)
+ attr
+ for attr in self.manager
+ if self.manager[attr].impl.expire_missing
+ )
@property
def _unloaded_non_object(self):
return self.unloaded.intersection(
- attr for attr in self.manager
+ attr
+ for attr in self.manager
if self.manager[attr].impl.accepts_scalar_loader
)
@@ -695,14 +721,16 @@ class InstanceState(interfaces.InspectionAttrInfo):
return None
def _modified_event(
- self, dict_, attr, previous, collection=False, is_userland=False):
+ self, dict_, attr, previous, collection=False, is_userland=False
+ ):
if attr:
if not attr.send_modified_events:
return
if is_userland and attr.key not in dict_:
raise sa_exc.InvalidRequestError(
"Can't flag attribute '%s' modified; it's not present in "
- "the object state" % attr.key)
+ "the object state" % attr.key
+ )
if attr.key not in self.committed_state or is_userland:
if collection:
if previous is NEVER_SET:
@@ -718,8 +746,7 @@ class InstanceState(interfaces.InspectionAttrInfo):
# assert self._strong_obj is None or self.modified
- if (self.session_id and self._strong_obj is None) \
- or not self.modified:
+ if (self.session_id and self._strong_obj is None) or not self.modified:
self.modified = True
instance_dict = self._instance_dict()
if instance_dict:
@@ -737,10 +764,8 @@ class InstanceState(interfaces.InspectionAttrInfo):
"Can't emit change event for attribute '%s' - "
"parent object of type %s has been garbage "
"collected."
- % (
- self.manager[attr.key],
- base.state_class_str(self)
- ))
+ % (self.manager[attr.key], base.state_class_str(self))
+ )
def _commit(self, dict_, keys):
"""Commit attributes.
@@ -758,17 +783,18 @@ class InstanceState(interfaces.InspectionAttrInfo):
self.expired = False
self.expired_attributes.difference_update(
- set(keys).intersection(dict_))
+ set(keys).intersection(dict_)
+ )
# the per-keys commit removes object-level callables,
# while that of commit_all does not. it's not clear
# if this behavior has a clear rationale, however tests do
# ensure this is what it does.
if self.callables:
- for key in set(self.callables).\
- intersection(keys).\
- intersection(dict_):
- del self.callables[key]
+ for key in (
+ set(self.callables).intersection(keys).intersection(dict_)
+ ):
+ del self.callables[key]
def _commit_all(self, dict_, instance_dict=None):
"""commit all attributes unconditionally.
@@ -797,8 +823,8 @@ class InstanceState(interfaces.InspectionAttrInfo):
state.committed_state.clear()
- if '_pending_mutations' in state_dict:
- del state_dict['_pending_mutations']
+ if "_pending_mutations" in state_dict:
+ del state_dict["_pending_mutations"]
state.expired_attributes.difference_update(dict_)
@@ -848,7 +874,8 @@ class AttributeState(object):
"""
return self.state.manager[self.key].__get__(
- self.state.obj(), self.state.class_)
+ self.state.obj(), self.state.class_
+ )
@property
def history(self):
@@ -866,8 +893,7 @@ class AttributeState(object):
:func:`.attributes.get_history` - underlying function
"""
- return self.state.get_history(self.key,
- PASSIVE_NO_INITIALIZE)
+ return self.state.get_history(self.key, PASSIVE_NO_INITIALIZE)
def load_history(self):
"""Return the current pre-flush change history for
@@ -885,8 +911,7 @@ class AttributeState(object):
.. versionadded:: 0.9.0
"""
- return self.state.get_history(self.key,
- PASSIVE_OFF ^ INIT_OK)
+ return self.state.get_history(self.key, PASSIVE_OFF ^ INIT_OK)
class PendingCollection(object):
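# A throwaway sanity check (names invented) that the state.py hunks above
# are behavior-neutral formatting changes: quote style, exploded
# comprehensions, and trailing commas do not alter what the code does.
state_dict = {'_pending_mutations': [], 'parents': {}}

if '_pending_mutations' in state_dict:  # old style: single quotes
    del state_dict['_pending_mutations']
if "parents" in state_dict:  # new style: double quotes
    del state_dict["parents"]

assert state_dict == {}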
diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py
index 47791f9b9..5c972b26b 100644
--- a/lib/sqlalchemy/orm/strategies.py
+++ b/lib/sqlalchemy/orm/strategies.py
@@ -13,22 +13,27 @@ from .. import util, log, event
from ..sql import util as sql_util, visitors
from .. import sql
from . import (
- attributes, interfaces, exc as orm_exc, loading,
- unitofwork, util as orm_util, query
+ attributes,
+ interfaces,
+ exc as orm_exc,
+ loading,
+ unitofwork,
+ util as orm_util,
+ query,
)
from .state import InstanceState
from .util import _none_set, aliased
from . import properties
-from .interfaces import (
- LoaderStrategy, StrategizedProperty
-)
+from .interfaces import LoaderStrategy, StrategizedProperty
from .base import _SET_DEFERRED_EXPIRED, _DEFER_FOR_STATE
from .session import _state_session
import itertools
def _register_attribute(
- prop, mapper, useobject,
+ prop,
+ mapper,
+ useobject,
compare_function=None,
typecallable=None,
callable_=None,
@@ -51,8 +56,8 @@ def _register_attribute(
fn, opts = prop.parent.validators[prop.key]
listen_hooks.append(
lambda desc, prop: orm_util._validator_events(
- desc,
- prop.key, fn, **opts)
+ desc, prop.key, fn, **opts
+ )
)
if useobject:
@@ -65,9 +70,7 @@ def _register_attribute(
if backref:
listen_hooks.append(
lambda desc, prop: attributes.backref_listeners(
- desc,
- backref,
- uselist
+ desc, backref, uselist
)
)
@@ -83,8 +86,9 @@ def _register_attribute(
# on mappers not already being set up so we have to check each one.
for m in mapper.self_and_descendants:
- if prop is m._props.get(prop.key) and \
- not m.class_manager._attr_has_impl(prop.key):
+ if prop is m._props.get(
+ prop.key
+ ) and not m.class_manager._attr_has_impl(prop.key):
desc = attributes.register_attribute_impl(
m.class_,
@@ -94,9 +98,11 @@ def _register_attribute(
compare_function=compare_function,
useobject=useobject,
extension=attribute_ext,
- trackparent=useobject and (
- prop.single_parent or
- prop.direction is interfaces.ONETOMANY),
+ trackparent=useobject
+ and (
+ prop.single_parent
+ or prop.direction is interfaces.ONETOMANY
+ ),
typecallable=typecallable,
callable_=callable_,
active_history=active_history,
@@ -118,23 +124,31 @@ class UninstrumentedColumnLoader(LoaderStrategy):
if the argument is against the with_polymorphic selectable.
"""
- __slots__ = 'columns',
+
+ __slots__ = ("columns",)
def __init__(self, parent, strategy_key):
super(UninstrumentedColumnLoader, self).__init__(parent, strategy_key)
self.columns = self.parent_property.columns
def setup_query(
- self, context, entity, path, loadopt, adapter,
- column_collection=None, **kwargs):
+ self,
+ context,
+ entity,
+ path,
+ loadopt,
+ adapter,
+ column_collection=None,
+ **kwargs
+ ):
for c in self.columns:
if adapter:
c = adapter.columns[c]
column_collection.append(c)
def create_row_processor(
- self, context, path, loadopt,
- mapper, result, adapter, populators):
+ self, context, path, loadopt, mapper, result, adapter, populators
+ ):
pass
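# The __slots__ change above is purely cosmetic: the trailing comma already
# made a one-element tuple, the new form only adds explicit parentheses.
# Tiny illustration with throwaway classes:
class Old(object):
    __slots__ = 'columns',          # tuple via trailing comma

class New(object):
    __slots__ = ("columns",)        # same tuple, written explicitly

assert Old.__slots__ == New.__slots__ == ("columns",)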
@@ -143,16 +157,24 @@ class UninstrumentedColumnLoader(LoaderStrategy):
class ColumnLoader(LoaderStrategy):
"""Provide loading behavior for a :class:`.ColumnProperty`."""
- __slots__ = 'columns', 'is_composite'
+ __slots__ = "columns", "is_composite"
def __init__(self, parent, strategy_key):
super(ColumnLoader, self).__init__(parent, strategy_key)
self.columns = self.parent_property.columns
- self.is_composite = hasattr(self.parent_property, 'composite_class')
+ self.is_composite = hasattr(self.parent_property, "composite_class")
def setup_query(
- self, context, entity, path, loadopt,
- adapter, column_collection, memoized_populators, **kwargs):
+ self,
+ context,
+ entity,
+ path,
+ loadopt,
+ adapter,
+ column_collection,
+ memoized_populators,
+ **kwargs
+ ):
for c in self.columns:
if adapter:
@@ -168,19 +190,23 @@ class ColumnLoader(LoaderStrategy):
self.is_class_level = True
coltype = self.columns[0].type
        # TODO: check all columns? check for foreign key as well?
- active_history = self.parent_property.active_history or \
- self.columns[0].primary_key or \
- mapper.version_id_col in set(self.columns)
+ active_history = (
+ self.parent_property.active_history
+ or self.columns[0].primary_key
+ or mapper.version_id_col in set(self.columns)
+ )
_register_attribute(
- self.parent_property, mapper, useobject=False,
+ self.parent_property,
+ mapper,
+ useobject=False,
compare_function=coltype.compare_values,
- active_history=active_history
+ active_history=active_history,
)
def create_row_processor(
- self, context, path,
- loadopt, mapper, result, adapter, populators):
+ self, context, path, loadopt, mapper, result, adapter, populators
+ ):
# look through list of columns represented here
# to see which, if any, is present in the row.
for col in self.columns:
@@ -201,8 +227,16 @@ class ExpressionColumnLoader(ColumnLoader):
super(ExpressionColumnLoader, self).__init__(parent, strategy_key)
def setup_query(
- self, context, entity, path, loadopt,
- adapter, column_collection, memoized_populators, **kwargs):
+ self,
+ context,
+ entity,
+ path,
+ loadopt,
+ adapter,
+ column_collection,
+ memoized_populators,
+ **kwargs
+ ):
if loadopt and "expression" in loadopt.local_opts:
columns = [loadopt.local_opts["expression"]]
@@ -218,8 +252,8 @@ class ExpressionColumnLoader(ColumnLoader):
memoized_populators[self.parent_property] = fetch
def create_row_processor(
- self, context, path,
- loadopt, mapper, result, adapter, populators):
+ self, context, path, loadopt, mapper, result, adapter, populators
+ ):
# look through list of columns represented here
# to see which, if any, is present in the row.
if loadopt and "expression" in loadopt.local_opts:
@@ -239,9 +273,11 @@ class ExpressionColumnLoader(ColumnLoader):
self.is_class_level = True
_register_attribute(
- self.parent_property, mapper, useobject=False,
+ self.parent_property,
+ mapper,
+ useobject=False,
compare_function=self.columns[0].type.compare_values,
- accepts_scalar_loader=False
+ accepts_scalar_loader=False,
)
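# ExpressionColumnLoader backs the query_expression() feature; a minimal,
# hypothetical usage sketch (the Account model and column names are
# invented for illustration, not taken from this diff):
from sqlalchemy import Column, Integer
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import query_expression, with_expression

Base = declarative_base()

class Account(Base):
    __tablename__ = "account"
    id = Column(Integer, primary_key=True)
    balance = Column(Integer)
    # populated only when the query supplies an expression; otherwise
    # create_row_processor() above leaves/expires it
    doubled = query_expression()

# the loader option that puts "expression" into loadopt.local_opts:
opt = with_expression(Account.doubled, Account.balance * 2)
# hypothetical query: session.query(Account).options(opt)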
@@ -251,27 +287,29 @@ class ExpressionColumnLoader(ColumnLoader):
class DeferredColumnLoader(LoaderStrategy):
"""Provide loading behavior for a deferred :class:`.ColumnProperty`."""
- __slots__ = 'columns', 'group'
+ __slots__ = "columns", "group"
def __init__(self, parent, strategy_key):
super(DeferredColumnLoader, self).__init__(parent, strategy_key)
- if hasattr(self.parent_property, 'composite_class'):
- raise NotImplementedError("Deferred loading for composite "
- "types not implemented yet")
+ if hasattr(self.parent_property, "composite_class"):
+ raise NotImplementedError(
+ "Deferred loading for composite " "types not implemented yet"
+ )
self.columns = self.parent_property.columns
self.group = self.parent_property.group
def create_row_processor(
- self, context, path, loadopt,
- mapper, result, adapter, populators):
+ self, context, path, loadopt, mapper, result, adapter, populators
+ ):
# this path currently does not check the result
# for the column; this is because in most cases we are
# working just with the setup_query() directive which does
# not support this, and the behavior here should be consistent.
if not self.is_class_level:
- set_deferred_for_local_state = \
+ set_deferred_for_local_state = (
self.parent_property._deferred_column_loader
+ )
populators["new"].append((self.key, set_deferred_for_local_state))
else:
populators["expire"].append((self.key, False))
@@ -280,41 +318,56 @@ class DeferredColumnLoader(LoaderStrategy):
self.is_class_level = True
_register_attribute(
- self.parent_property, mapper, useobject=False,
+ self.parent_property,
+ mapper,
+ useobject=False,
compare_function=self.columns[0].type.compare_values,
callable_=self._load_for_state,
- expire_missing=False
+ expire_missing=False,
)
def setup_query(
- self, context, entity, path, loadopt,
- adapter, column_collection, memoized_populators,
- only_load_props=None, **kw):
+ self,
+ context,
+ entity,
+ path,
+ loadopt,
+ adapter,
+ column_collection,
+ memoized_populators,
+ only_load_props=None,
+ **kw
+ ):
if (
(
- loadopt and
- 'undefer_pks' in loadopt.local_opts and
- set(self.columns).intersection(
- self.parent._should_undefer_in_wildcard)
- )
- or
- (
- loadopt and
- self.group and
- loadopt.local_opts.get('undefer_group_%s' % self.group, False)
+ loadopt
+ and "undefer_pks" in loadopt.local_opts
+ and set(self.columns).intersection(
+ self.parent._should_undefer_in_wildcard
+ )
)
- or
- (
- only_load_props and self.key in only_load_props
+ or (
+ loadopt
+ and self.group
+ and loadopt.local_opts.get(
+ "undefer_group_%s" % self.group, False
+ )
)
+ or (only_load_props and self.key in only_load_props)
):
self.parent_property._get_strategy(
(("deferred", False), ("instrument", True))
).setup_query(
- context, entity,
- path, loadopt, adapter,
- column_collection, memoized_populators, **kw)
+ context,
+ entity,
+ path,
+ loadopt,
+ adapter,
+ column_collection,
+ memoized_populators,
+ **kw
+ )
elif self.is_class_level:
memoized_populators[self.parent_property] = _SET_DEFERRED_EXPIRED
else:
@@ -331,11 +384,11 @@ class DeferredColumnLoader(LoaderStrategy):
if self.group:
toload = [
- p.key for p in
- localparent.iterate_properties
- if isinstance(p, StrategizedProperty) and
- isinstance(p.strategy, DeferredColumnLoader) and
- p.group == self.group
+ p.key
+ for p in localparent.iterate_properties
+ if isinstance(p, StrategizedProperty)
+ and isinstance(p.strategy, DeferredColumnLoader)
+ and p.group == self.group
]
else:
toload = [self.key]
@@ -347,14 +400,17 @@ class DeferredColumnLoader(LoaderStrategy):
if session is None:
raise orm_exc.DetachedInstanceError(
"Parent instance %s is not bound to a Session; "
- "deferred load operation of attribute '%s' cannot proceed" %
- (orm_util.state_str(state), self.key)
+ "deferred load operation of attribute '%s' cannot proceed"
+ % (orm_util.state_str(state), self.key)
)
query = session.query(localparent)
- if loading.load_on_ident(
- query, state.key,
- only_load_props=group, refresh_state=state) is None:
+ if (
+ loading.load_on_ident(
+ query, state.key, only_load_props=group, refresh_state=state
+ )
+ is None
+ ):
raise orm_exc.ObjectDeletedError(state)
return attributes.ATTR_WAS_SET
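# DeferredColumnLoader is the strategy selected by deferred()/undefer();
# a hypothetical sketch of the group behavior in _load_for_state() above
# (the Document model and "content" group name are illustrative):
from sqlalchemy import Column, Integer, Text
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import deferred, undefer_group

Base = declarative_base()

class Document(Base):
    __tablename__ = "document"
    id = Column(Integer, primary_key=True)
    # both deferred columns share one group, so touching either one
    # un-defers both in a single load_on_ident() SELECT
    body = deferred(Column(Text), group="content")
    summary = deferred(Column(Text), group="content")

# hypothetical: opt back in for one query
opt = undefer_group("content")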
@@ -378,7 +434,7 @@ class LoadDeferredColumns(object):
class AbstractRelationshipLoader(LoaderStrategy):
"""LoaderStratgies which deal with related objects."""
- __slots__ = 'mapper', 'target', 'uselist'
+ __slots__ = "mapper", "target", "uselist"
def __init__(self, parent, strategy_key):
super(AbstractRelationshipLoader, self).__init__(parent, strategy_key)
@@ -414,19 +470,21 @@ class NoLoader(AbstractRelationshipLoader):
self.is_class_level = True
_register_attribute(
- self.parent_property, mapper,
+ self.parent_property,
+ mapper,
useobject=True,
typecallable=self.parent_property.collection_class,
)
def create_row_processor(
- self, context, path, loadopt, mapper,
- result, adapter, populators):
+ self, context, path, loadopt, mapper, result, adapter, populators
+ ):
def invoke_no_load(state, dict_, row):
if self.uselist:
state.manager.get_impl(self.key).initialize(state, dict_)
else:
dict_[self.key] = None
+
populators["new"].append((self.key, invoke_no_load))
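# NoLoader corresponds to lazy="noload": the attribute is initialized but
# never emits SQL. A hypothetical mapping sketch (Parent/Child invented):
from sqlalchemy import Column, ForeignKey, Integer
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship

Base = declarative_base()

class Parent(Base):
    __tablename__ = "parent"
    id = Column(Integer, primary_key=True)
    # uselist collection: invoke_no_load() above initializes it empty;
    # a scalar (many-to-one) would instead be set to None
    children = relationship("Child", lazy="noload")

class Child(Base):
    __tablename__ = "child"
    id = Column(Integer, primary_key=True)
    parent_id = Column(Integer, ForeignKey("parent.id"))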
@@ -443,10 +501,18 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots):
"""
__slots__ = (
- '_lazywhere', '_rev_lazywhere', 'use_get', '_bind_to_col',
- '_equated_columns', '_rev_bind_to_col', '_rev_equated_columns',
- '_simple_lazy_clause', '_raise_always', '_raise_on_sql',
- '_bakery')
+ "_lazywhere",
+ "_rev_lazywhere",
+ "use_get",
+ "_bind_to_col",
+ "_equated_columns",
+ "_rev_bind_to_col",
+ "_rev_equated_columns",
+ "_simple_lazy_clause",
+ "_raise_always",
+ "_raise_on_sql",
+ "_bakery",
+ )
def __init__(self, parent, strategy_key):
super(LazyLoader, self).__init__(parent, strategy_key)
@@ -454,25 +520,23 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots):
self._raise_on_sql = self.strategy_opts["lazy"] == "raise_on_sql"
join_condition = self.parent_property._join_condition
- self._lazywhere, \
- self._bind_to_col, \
- self._equated_columns = join_condition.create_lazy_clause()
+ self._lazywhere, self._bind_to_col, self._equated_columns = (
+ join_condition.create_lazy_clause()
+ )
- self._rev_lazywhere, \
- self._rev_bind_to_col, \
- self._rev_equated_columns = join_condition.create_lazy_clause(
- reverse_direction=True)
+ self._rev_lazywhere, self._rev_bind_to_col, self._rev_equated_columns = join_condition.create_lazy_clause(
+ reverse_direction=True
+ )
self.logger.info("%s lazy loading clause %s", self, self._lazywhere)
# determine if our "lazywhere" clause is the same as the mapper's
# get() clause. then we can just use mapper.get()
- self.use_get = not self.uselist and \
- self.mapper._get_clause[0].compare(
- self._lazywhere,
- use_proxies=True,
- equivalents=self.mapper._equivalent_columns
- )
+ self.use_get = not self.uselist and self.mapper._get_clause[0].compare(
+ self._lazywhere,
+ use_proxies=True,
+ equivalents=self.mapper._equivalent_columns,
+ )
if self.use_get:
for col in list(self._equated_columns):
@@ -480,16 +544,17 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots):
for c in self.mapper._equivalent_columns[col]:
self._equated_columns[c] = self._equated_columns[col]
- self.logger.info("%s will use query.get() to "
- "optimize instance loads", self)
+ self.logger.info(
+ "%s will use query.get() to " "optimize instance loads", self
+ )
def init_class_attribute(self, mapper):
self.is_class_level = True
active_history = (
- self.parent_property.active_history or
- self.parent_property.direction is not interfaces.MANYTOONE or
- not self.use_get
+ self.parent_property.active_history
+ or self.parent_property.direction is not interfaces.MANYTOONE
+ or not self.use_get
)
# MANYTOONE currently only needs the
@@ -504,28 +569,29 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots):
useobject=True,
callable_=self._load_for_state,
typecallable=self.parent_property.collection_class,
- active_history=active_history
+ active_history=active_history,
)
def _memoized_attr__simple_lazy_clause(self):
- criterion, bind_to_col = (
- self._lazywhere,
- self._bind_to_col
- )
+ criterion, bind_to_col = (self._lazywhere, self._bind_to_col)
params = []
def visit_bindparam(bindparam):
bindparam.unique = False
if bindparam._identifying_key in bind_to_col:
- params.append((
- bindparam.key, bind_to_col[bindparam._identifying_key],
- None))
+ params.append(
+ (
+ bindparam.key,
+ bind_to_col[bindparam._identifying_key],
+ None,
+ )
+ )
elif bindparam.callable is None:
params.append((bindparam.key, None, bindparam.value))
criterion = visitors.cloned_traverse(
- criterion, {}, {'bindparam': visit_bindparam}
+ criterion, {}, {"bindparam": visit_bindparam}
)
return criterion, params
@@ -535,7 +601,8 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots):
if state is None:
return sql_util.adapt_criterion_to_null(
- criterion, [key for key, ident, value in param_keys])
+ criterion, [key for key, ident, value in param_keys]
+ )
mapper = self.parent_property.parent
@@ -550,10 +617,12 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots):
if ident is not None:
if passive and passive & attributes.LOAD_AGAINST_COMMITTED:
value = mapper._get_committed_state_attr_by_column(
- state, dict_, ident, passive)
+ state, dict_, ident, passive
+ )
else:
value = mapper._get_state_attr_by_column(
- state, dict_, ident, passive)
+ state, dict_, ident, passive
+ )
params[key] = value
@@ -567,21 +636,19 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots):
def _load_for_state(self, state, passive):
if not state.key and (
- (
- not self.parent_property.load_on_pending
- and not state._load_pending
- )
- or not state.session_id
+ (
+ not self.parent_property.load_on_pending
+ and not state._load_pending
+ )
+ or not state.session_id
):
return attributes.ATTR_EMPTY
pending = not state.key
primary_key_identity = None
- if (
- (not passive & attributes.SQL_OK and not self.use_get)
- or
- (not passive & attributes.NON_PERSISTENT_OK and pending)
+ if (not passive & attributes.SQL_OK and not self.use_get) or (
+ not passive & attributes.NON_PERSISTENT_OK and pending
):
return attributes.PASSIVE_NO_RESULT
@@ -595,17 +662,15 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots):
raise orm_exc.DetachedInstanceError(
"Parent instance %s is not bound to a Session; "
- "lazy load operation of attribute '%s' cannot proceed" %
- (orm_util.state_str(state), self.key)
+ "lazy load operation of attribute '%s' cannot proceed"
+ % (orm_util.state_str(state), self.key)
)
# if we have a simple primary key load, check the
# identity map without generating a Query at all
if self.use_get:
primary_key_identity = self._get_ident_for_use_get(
- session,
- state,
- passive
+ session, state, passive
)
if attributes.PASSIVE_NO_RESULT in primary_key_identity:
return attributes.PASSIVE_NO_RESULT
@@ -620,18 +685,23 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots):
# does this, including how it decides what the correct
# identity_token would be for this identity.
instance = session.query()._identity_lookup(
- self.mapper, primary_key_identity, passive=passive,
- lazy_loaded_from=state
+ self.mapper,
+ primary_key_identity,
+ passive=passive,
+ lazy_loaded_from=state,
)
if instance is not None:
return instance
- elif not passive & attributes.SQL_OK or \
- not passive & attributes.RELATED_OBJECT_OK:
+ elif (
+ not passive & attributes.SQL_OK
+ or not passive & attributes.RELATED_OBJECT_OK
+ ):
return attributes.PASSIVE_NO_RESULT
return self._emit_lazyload(
- session, state, primary_key_identity, passive)
+ session, state, primary_key_identity, passive
+ )
def _get_ident_for_use_get(self, session, state, passive):
instance_mapper = state.manager.mapper
@@ -644,11 +714,7 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots):
dict_ = state.dict
return [
- get_attr(
- state,
- dict_,
- self._equated_columns[pk],
- passive=passive)
+ get_attr(state, dict_, self._equated_columns[pk], passive=passive)
for pk in self.mapper.primary_key
]
@@ -656,11 +722,10 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots):
def _memoized_attr__bakery(self, baked):
return baked.bakery(size=50)
- @util.dependencies(
- "sqlalchemy.orm.strategy_options")
+ @util.dependencies("sqlalchemy.orm.strategy_options")
def _emit_lazyload(
- self, strategy_options, session, state,
- primary_key_identity, passive):
+ self, strategy_options, session, state, primary_key_identity, passive
+ ):
# emit lazy load now using BakedQuery, to cut way down on the overhead
# of generating queries.
# there are two big things we are trying to guard against here:
@@ -688,15 +753,18 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots):
q.add_criteria(
lambda q: q._adapt_all_clauses()._with_invoke_all_eagers(False),
- self.parent_property)
+ self.parent_property,
+ )
if not self.parent_property.bake_queries:
q.spoil(full=True)
if self.parent_property.secondary is not None:
q.add_criteria(
- lambda q:
- q.select_from(self.mapper, self.parent_property.secondary))
+ lambda q: q.select_from(
+ self.mapper, self.parent_property.secondary
+ )
+ )
pending = not state.key
@@ -712,35 +780,38 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots):
# is usually a throwaway object.
effective_path = state.load_path[self.parent_property]
- q._add_lazyload_options(
- state.load_options, effective_path
- )
+ q._add_lazyload_options(state.load_options, effective_path)
if self.use_get:
if self._raise_on_sql:
self._invoke_raise_load(state, passive, "raise_on_sql")
- return q(session).\
- with_post_criteria(lambda q: q._set_lazyload_from(state)).\
- _load_on_pk_identity(
- session.query(self.mapper),
- primary_key_identity)
+ return (
+ q(session)
+ .with_post_criteria(lambda q: q._set_lazyload_from(state))
+ ._load_on_pk_identity(
+ session.query(self.mapper), primary_key_identity
+ )
+ )
if self.parent_property.order_by:
q.add_criteria(
- lambda q:
- q.order_by(*util.to_list(self.parent_property.order_by)))
+ lambda q: q.order_by(
+ *util.to_list(self.parent_property.order_by)
+ )
+ )
for rev in self.parent_property._reverse_property:
# reverse props that are MANYTOONE are loading *this*
# object from get(), so don't need to eager out to those.
- if rev.direction is interfaces.MANYTOONE and \
- rev._use_get and \
- not isinstance(rev.strategy, LazyLoader):
+ if (
+ rev.direction is interfaces.MANYTOONE
+ and rev._use_get
+ and not isinstance(rev.strategy, LazyLoader)
+ ):
q.add_criteria(
- lambda q:
- q.options(
+ lambda q: q.options(
strategy_options.Load.for_existing_path(
q._current_path[rev.parent]
).lazyload(rev.key)
@@ -750,8 +821,7 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots):
lazy_clause, params = self._generate_lazy_clause(state, passive)
if pending:
- if util.has_intersection(
- orm_util._none_set, params.values()):
+ if util.has_intersection(orm_util._none_set, params.values()):
return None
elif util.has_intersection(orm_util._never_set, params.values()):
@@ -769,9 +839,12 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots):
q._params = params
return q
- result = q(session).\
- with_post_criteria(lambda q: q._set_lazyload_from(state)).\
- with_post_criteria(set_default_params).all()
+ result = (
+ q(session)
+ .with_post_criteria(lambda q: q._set_lazyload_from(state))
+ .with_post_criteria(set_default_params)
+ .all()
+ )
if self.uselist:
return result
else:
@@ -781,15 +854,16 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots):
util.warn(
"Multiple rows returned with "
"uselist=False for lazily-loaded attribute '%s' "
- % self.parent_property)
+ % self.parent_property
+ )
return result[0]
else:
return None
def create_row_processor(
- self, context, path, loadopt,
- mapper, result, adapter, populators):
+ self, context, path, loadopt, mapper, result, adapter, populators
+ ):
key = self.key
if not self.is_class_level:
@@ -802,11 +876,12 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots):
# attribute - "eager" attributes always have a
# class-level lazyloader installed.
set_lazy_callable = InstanceState._instance_level_callable_processor(
- mapper.class_manager,
- LoadLazyAttribute(key, self), key)
+ mapper.class_manager, LoadLazyAttribute(key, self), key
+ )
populators["new"].append((self.key, set_lazy_callable))
elif context.populate_existing or mapper.always_refresh:
+
def reset_for_lazy_callable(state, dict_, row):
# we are the primary manager for this attribute on
# this class - reset its
@@ -842,19 +917,26 @@ class ImmediateLoader(AbstractRelationshipLoader):
__slots__ = ()
def init_class_attribute(self, mapper):
- self.parent_property.\
- _get_strategy((("lazy", "select"),)).\
- init_class_attribute(mapper)
+ self.parent_property._get_strategy(
+ (("lazy", "select"),)
+ ).init_class_attribute(mapper)
def setup_query(
- self, context, entity,
- path, loadopt, adapter, column_collection=None,
- parentmapper=None, **kwargs):
+ self,
+ context,
+ entity,
+ path,
+ loadopt,
+ adapter,
+ column_collection=None,
+ parentmapper=None,
+ **kwargs
+ ):
pass
def create_row_processor(
- self, context, path, loadopt,
- mapper, result, adapter, populators):
+ self, context, path, loadopt, mapper, result, adapter, populators
+ ):
def load_immediate(state, dict_, row):
state.get_impl(self.key).get(state, dict_)
@@ -864,22 +946,28 @@ class ImmediateLoader(AbstractRelationshipLoader):
@log.class_logger
@properties.RelationshipProperty.strategy_for(lazy="subquery")
class SubqueryLoader(AbstractRelationshipLoader):
- __slots__ = 'join_depth',
+ __slots__ = ("join_depth",)
def __init__(self, parent, strategy_key):
super(SubqueryLoader, self).__init__(parent, strategy_key)
self.join_depth = self.parent_property.join_depth
def init_class_attribute(self, mapper):
- self.parent_property.\
- _get_strategy((("lazy", "select"),)).\
- init_class_attribute(mapper)
+ self.parent_property._get_strategy(
+ (("lazy", "select"),)
+ ).init_class_attribute(mapper)
def setup_query(
- self, context, entity,
- path, loadopt, adapter,
- column_collection=None,
- parentmapper=None, **kwargs):
+ self,
+ context,
+ entity,
+ path,
+ loadopt,
+ adapter,
+ column_collection=None,
+ parentmapper=None,
+ **kwargs
+ ):
if not context.query._enable_eagerloads:
return
@@ -891,16 +979,16 @@ class SubqueryLoader(AbstractRelationshipLoader):
# build up a path indicating the path from the leftmost
# entity to the thing we're subquery loading.
with_poly_info = path.get(
- context.attributes,
- "path_with_polymorphic", None)
+ context.attributes, "path_with_polymorphic", None
+ )
if with_poly_info is not None:
effective_entity = with_poly_info.entity
else:
effective_entity = self.mapper
subq_path = context.attributes.get(
- ('subquery_path', None),
- orm_util.PathRegistry.root)
+ ("subquery_path", None), orm_util.PathRegistry.root
+ )
subq_path = subq_path + path
@@ -909,27 +997,33 @@ class SubqueryLoader(AbstractRelationshipLoader):
if not path.contains(context.attributes, "loader"):
if self.join_depth:
if (
- (context.query._current_path.length
- if context.query._current_path else 0) +
- path.length
+ (
+ context.query._current_path.length
+ if context.query._current_path
+ else 0
+ )
+ + path.length
) / 2 > self.join_depth:
return
elif subq_path.contains_mapper(self.mapper):
return
- leftmost_mapper, leftmost_attr, leftmost_relationship = \
- self._get_leftmost(subq_path)
+ leftmost_mapper, leftmost_attr, leftmost_relationship = self._get_leftmost(
+ subq_path
+ )
orig_query = context.attributes.get(
- ("orig_query", SubqueryLoader),
- context.query)
+ ("orig_query", SubqueryLoader), context.query
+ )
# generate a new Query from the original, then
# produce a subquery from it.
left_alias = self._generate_from_original_query(
- orig_query, leftmost_mapper,
- leftmost_attr, leftmost_relationship,
- entity.entity_zero
+ orig_query,
+ leftmost_mapper,
+ leftmost_attr,
+ leftmost_relationship,
+ entity.entity_zero,
)
# generate another Query that will join the
@@ -940,17 +1034,18 @@ class SubqueryLoader(AbstractRelationshipLoader):
q = orig_query.session.query(effective_entity)
q._attributes = {
("orig_query", SubqueryLoader): orig_query,
- ('subquery_path', None): subq_path
+ ("subquery_path", None): subq_path,
}
q = q._set_enable_single_crit(False)
- to_join, local_attr, parent_alias = \
- self._prep_for_joins(left_alias, subq_path)
+ to_join, local_attr, parent_alias = self._prep_for_joins(
+ left_alias, subq_path
+ )
q = q.order_by(*local_attr)
q = q.add_columns(*local_attr)
q = self._apply_joins(
- q, to_join, left_alias,
- parent_alias, effective_entity)
+ q, to_join, left_alias, parent_alias, effective_entity
+ )
q = self._setup_options(q, subq_path, orig_query, effective_entity)
q = self._setup_outermost_orderby(q)
@@ -964,21 +1059,20 @@ class SubqueryLoader(AbstractRelationshipLoader):
subq_mapper = orm_util._class_to_mapper(subq_path[0])
# determine attributes of the leftmost mapper
- if self.parent.isa(subq_mapper) and \
- self.parent_property is subq_path[1]:
- leftmost_mapper, leftmost_prop = \
- self.parent, self.parent_property
+ if (
+ self.parent.isa(subq_mapper)
+ and self.parent_property is subq_path[1]
+ ):
+ leftmost_mapper, leftmost_prop = self.parent, self.parent_property
else:
- leftmost_mapper, leftmost_prop = \
- subq_mapper, \
- subq_path[1]
+ leftmost_mapper, leftmost_prop = subq_mapper, subq_path[1]
leftmost_cols = leftmost_prop.local_columns
leftmost_attr = [
getattr(
- subq_path[0].entity,
- leftmost_mapper._columntoproperty[c].key)
+ subq_path[0].entity, leftmost_mapper._columntoproperty[c].key
+ )
for c in leftmost_cols
]
@@ -986,8 +1080,11 @@ class SubqueryLoader(AbstractRelationshipLoader):
def _generate_from_original_query(
self,
- orig_query, leftmost_mapper,
- leftmost_attr, leftmost_relationship, orig_entity
+ orig_query,
+ leftmost_mapper,
+ leftmost_attr,
+ leftmost_relationship,
+ orig_entity,
):
# reformat the original query
# to look only for significant columns
@@ -999,11 +1096,16 @@ class SubqueryLoader(AbstractRelationshipLoader):
# all entities mentioned in things like WHERE, JOIN, etc.
if not q._from_obj:
q._set_select_from(
- list(set([
- ent['entity'] for ent in orig_query.column_descriptions
- if ent['entity'] is not None
- ])),
- False
+ list(
+ set(
+ [
+ ent["entity"]
+ for ent in orig_query.column_descriptions
+ if ent["entity"] is not None
+ ]
+ )
+ ),
+ False,
)
# select from the identity columns of the outer (specifically, these
@@ -1037,8 +1139,8 @@ class SubqueryLoader(AbstractRelationshipLoader):
embed_q = q.with_labels().subquery()
left_alias = orm_util.AliasedClass(
- leftmost_mapper, embed_q,
- use_mapper_path=True)
+ leftmost_mapper, embed_q, use_mapper_path=True
+ )
return left_alias
def _prep_for_joins(self, left_alias, subq_path):
@@ -1077,8 +1179,8 @@ class SubqueryLoader(AbstractRelationshipLoader):
# alias a plain mapper as we may be
# joining multiple times
parent_alias = orm_util.AliasedClass(
- info.entity,
- use_mapper_path=True)
+ info.entity, use_mapper_path=True
+ )
local_cols = self.parent_property.local_columns
@@ -1089,8 +1191,8 @@ class SubqueryLoader(AbstractRelationshipLoader):
return to_join, local_attr, parent_alias
def _apply_joins(
- self, q, to_join, left_alias, parent_alias,
- effective_entity):
+ self, q, to_join, left_alias, parent_alias, effective_entity
+ ):
ltj = len(to_join)
if ltj == 1:
@@ -1100,7 +1202,9 @@ class SubqueryLoader(AbstractRelationshipLoader):
elif ltj == 2:
to_join = [
getattr(left_alias, to_join[0][1]).of_type(parent_alias),
- getattr(parent_alias, to_join[-1][1]).of_type(effective_entity)
+ getattr(parent_alias, to_join[-1][1]).of_type(
+ effective_entity
+ ),
]
elif ltj > 2:
middle = [
@@ -1108,8 +1212,9 @@ class SubqueryLoader(AbstractRelationshipLoader):
orm_util.AliasedClass(item[0])
if not inspect(item[0]).is_aliased_class
else item[0].entity,
- item[1]
- ) for item in to_join[1:-1]
+ item[1],
+ )
+ for item in to_join[1:-1]
]
inner = []
@@ -1123,11 +1228,15 @@ class SubqueryLoader(AbstractRelationshipLoader):
inner.append(attr)
- to_join = [
- getattr(left_alias, to_join[0][1]).of_type(inner[0].parent)
- ] + inner + [
- getattr(parent_alias, to_join[-1][1]).of_type(effective_entity)
- ]
+ to_join = (
+ [getattr(left_alias, to_join[0][1]).of_type(inner[0].parent)]
+ + inner
+ + [
+ getattr(parent_alias, to_join[-1][1]).of_type(
+ effective_entity
+ )
+ ]
+ )
for attr in to_join:
q = q.join(attr, from_joinpoint=True)
@@ -1151,13 +1260,9 @@ class SubqueryLoader(AbstractRelationshipLoader):
# this really only picks up the "secondary" table
# right now.
eagerjoin = q._from_obj[0]
- eager_order_by = \
- eagerjoin._target_adapter.\
- copy_and_process(
- util.to_list(
- self.parent_property.order_by
- )
- )
+ eager_order_by = eagerjoin._target_adapter.copy_and_process(
+ util.to_list(self.parent_property.order_by)
+ )
q = q.order_by(*eager_order_by)
return q
@@ -1167,6 +1272,7 @@ class SubqueryLoader(AbstractRelationshipLoader):
first moment a value is needed.
"""
+
_data = None
def __init__(self, subq):
@@ -1180,10 +1286,7 @@ class SubqueryLoader(AbstractRelationshipLoader):
def _load(self):
self._data = dict(
(k, [vv[0] for vv in v])
- for k, v in itertools.groupby(
- self.subq,
- lambda x: x[1:]
- )
+ for k, v in itertools.groupby(self.subq, lambda x: x[1:])
)
def loader(self, state, dict_, row):
@@ -1191,17 +1294,17 @@ class SubqueryLoader(AbstractRelationshipLoader):
self._load()
def create_row_processor(
- self, context, path, loadopt,
- mapper, result, adapter, populators):
+ self, context, path, loadopt, mapper, result, adapter, populators
+ ):
if not self.parent.class_manager[self.key].impl.supports_population:
raise sa_exc.InvalidRequestError(
"'%s' does not support object "
- "population - eager loading cannot be applied." %
- self)
+ "population - eager loading cannot be applied." % self
+ )
path = path[self.parent_property]
- subq = path.get(context.attributes, 'subquery')
+ subq = path.get(context.attributes, "subquery")
if subq is None:
return
@@ -1220,65 +1323,67 @@ class SubqueryLoader(AbstractRelationshipLoader):
collections = path.get(context.attributes, "collections")
if collections is None:
collections = self._SubqCollections(subq)
- path.set(context.attributes, 'collections', collections)
+ path.set(context.attributes, "collections", collections)
if adapter:
local_cols = [adapter.columns[c] for c in local_cols]
if self.uselist:
self._create_collection_loader(
- context, collections, local_cols, populators)
+ context, collections, local_cols, populators
+ )
else:
self._create_scalar_loader(
- context, collections, local_cols, populators)
+ context, collections, local_cols, populators
+ )
def _create_collection_loader(
- self, context, collections, local_cols, populators):
+ self, context, collections, local_cols, populators
+ ):
def load_collection_from_subq(state, dict_, row):
collection = collections.get(
- tuple([row[col] for col in local_cols]),
- ()
+ tuple([row[col] for col in local_cols]), ()
+ )
+ state.get_impl(self.key).set_committed_value(
+ state, dict_, collection
)
- state.get_impl(self.key).\
- set_committed_value(state, dict_, collection)
def load_collection_from_subq_existing_row(state, dict_, row):
if self.key not in dict_:
load_collection_from_subq(state, dict_, row)
- populators["new"].append(
- (self.key, load_collection_from_subq))
+ populators["new"].append((self.key, load_collection_from_subq))
populators["existing"].append(
- (self.key, load_collection_from_subq_existing_row))
+ (self.key, load_collection_from_subq_existing_row)
+ )
if context.invoke_all_eagers:
populators["eager"].append((self.key, collections.loader))
def _create_scalar_loader(
- self, context, collections, local_cols, populators):
+ self, context, collections, local_cols, populators
+ ):
def load_scalar_from_subq(state, dict_, row):
collection = collections.get(
- tuple([row[col] for col in local_cols]),
- (None,)
+ tuple([row[col] for col in local_cols]), (None,)
)
if len(collection) > 1:
util.warn(
"Multiple rows returned with "
- "uselist=False for eagerly-loaded attribute '%s' "
- % self)
+ "uselist=False for eagerly-loaded attribute '%s' " % self
+ )
scalar = collection[0]
- state.get_impl(self.key).\
- set_committed_value(state, dict_, scalar)
+ state.get_impl(self.key).set_committed_value(state, dict_, scalar)
def load_scalar_from_subq_existing_row(state, dict_, row):
if self.key not in dict_:
load_scalar_from_subq(state, dict_, row)
- populators["new"].append(
- (self.key, load_scalar_from_subq))
+ populators["new"].append((self.key, load_scalar_from_subq))
populators["existing"].append(
- (self.key, load_scalar_from_subq_existing_row))
+ (self.key, load_scalar_from_subq_existing_row)
+ )
if context.invoke_all_eagers:
populators["eager"].append((self.key, collections.loader))
@@ -1292,7 +1397,7 @@ class JoinedLoader(AbstractRelationshipLoader):
"""
- __slots__ = 'join_depth', '_aliased_class_pool'
+ __slots__ = "join_depth", "_aliased_class_pool"
def __init__(self, parent, strategy_key):
super(JoinedLoader, self).__init__(parent, strategy_key)
@@ -1300,14 +1405,22 @@ class JoinedLoader(AbstractRelationshipLoader):
self._aliased_class_pool = []
def init_class_attribute(self, mapper):
- self.parent_property.\
- _get_strategy((("lazy", "select"),)).init_class_attribute(mapper)
+ self.parent_property._get_strategy(
+ (("lazy", "select"),)
+ ).init_class_attribute(mapper)
def setup_query(
- self, context, entity, path, loadopt, adapter,
- column_collection=None, parentmapper=None,
- chained_from_outerjoin=False,
- **kwargs):
+ self,
+ context,
+ entity,
+ path,
+ loadopt,
+ adapter,
+ column_collection=None,
+ parentmapper=None,
+ chained_from_outerjoin=False,
+ **kwargs
+ ):
"""Add a left outer join to the statement that's being constructed."""
if not context.query._enable_eagerloads:
@@ -1319,15 +1432,16 @@ class JoinedLoader(AbstractRelationshipLoader):
with_polymorphic = None
- user_defined_adapter = self._init_user_defined_eager_proc(
- loadopt, context) if loadopt else False
+ user_defined_adapter = (
+ self._init_user_defined_eager_proc(loadopt, context)
+ if loadopt
+ else False
+ )
if user_defined_adapter is not False:
- clauses, adapter, add_to_collection = \
- self._setup_query_on_user_defined_adapter(
- context, entity, path, adapter,
- user_defined_adapter
- )
+ clauses, adapter, add_to_collection = self._setup_query_on_user_defined_adapter(
+ context, entity, path, adapter, user_defined_adapter
+ )
else:
# if not via query option, check for
# a cycle
@@ -1338,16 +1452,19 @@ class JoinedLoader(AbstractRelationshipLoader):
elif path.contains_mapper(self.mapper):
return
- clauses, adapter, add_to_collection, chained_from_outerjoin = \
- self._generate_row_adapter(
- context, entity, path, loadopt, adapter,
- column_collection, parentmapper, chained_from_outerjoin
- )
+ clauses, adapter, add_to_collection, chained_from_outerjoin = self._generate_row_adapter(
+ context,
+ entity,
+ path,
+ loadopt,
+ adapter,
+ column_collection,
+ parentmapper,
+ chained_from_outerjoin,
+ )
with_poly_info = path.get(
- context.attributes,
- "path_with_polymorphic",
- None
+ context.attributes, "path_with_polymorphic", None
)
if with_poly_info is not None:
with_polymorphic = with_poly_info.with_polymorphic_mappers
@@ -1357,14 +1474,20 @@ class JoinedLoader(AbstractRelationshipLoader):
path = path[self.mapper]
loading._setup_entity_query(
- context, self.mapper, entity,
- path, clauses, add_to_collection,
+ context,
+ self.mapper,
+ entity,
+ path,
+ clauses,
+ add_to_collection,
with_polymorphic=with_polymorphic,
parentmapper=self.mapper,
- chained_from_outerjoin=chained_from_outerjoin)
+ chained_from_outerjoin=chained_from_outerjoin,
+ )
- if with_poly_info is not None and \
- None in set(context.secondary_columns):
+ if with_poly_info is not None and None in set(
+ context.secondary_columns
+ ):
raise sa_exc.InvalidRequestError(
"Detected unaliased columns when generating joined "
"load. Make sure to use aliased=True or flat=True "
@@ -1383,8 +1506,8 @@ class JoinedLoader(AbstractRelationshipLoader):
# the option applies. check if the "user_defined_eager_row_processor"
# has been built up.
adapter = path.get(
- context.attributes,
- "user_defined_eager_row_processor", False)
+ context.attributes, "user_defined_eager_row_processor", False
+ )
if adapter is not False:
# just return it
return adapter
@@ -1394,38 +1517,39 @@ class JoinedLoader(AbstractRelationshipLoader):
root_mapper, prop = path[-2:]
- #from .mapper import Mapper
- #from .interfaces import MapperProperty
- #assert isinstance(root_mapper, Mapper)
- #assert isinstance(prop, MapperProperty)
+ # from .mapper import Mapper
+ # from .interfaces import MapperProperty
+ # assert isinstance(root_mapper, Mapper)
+ # assert isinstance(prop, MapperProperty)
if alias is not None:
if isinstance(alias, str):
alias = prop.target.alias(alias)
adapter = sql_util.ColumnAdapter(
- alias,
- equivalents=prop.mapper._equivalent_columns)
+ alias, equivalents=prop.mapper._equivalent_columns
+ )
else:
if path.contains(context.attributes, "path_with_polymorphic"):
with_poly_info = path.get(
- context.attributes,
- "path_with_polymorphic")
+ context.attributes, "path_with_polymorphic"
+ )
adapter = orm_util.ORMAdapter(
with_poly_info.entity,
- equivalents=prop.mapper._equivalent_columns)
+ equivalents=prop.mapper._equivalent_columns,
+ )
else:
adapter = context.query._polymorphic_adapters.get(
- prop.mapper, None)
+ prop.mapper, None
+ )
path.set(
- context.attributes,
- "user_defined_eager_row_processor",
- adapter)
+ context.attributes, "user_defined_eager_row_processor", adapter
+ )
return adapter
def _setup_query_on_user_defined_adapter(
- self, context, entity,
- path, adapter, user_defined_adapter):
+ self, context, entity, path, adapter, user_defined_adapter
+ ):
# apply some more wrapping to the "user defined adapter"
# if we are setting up the query for SQL render.
@@ -1434,13 +1558,17 @@ class JoinedLoader(AbstractRelationshipLoader):
if adapter and user_defined_adapter:
user_defined_adapter = user_defined_adapter.wrap(adapter)
path.set(
- context.attributes, "user_defined_eager_row_processor",
- user_defined_adapter)
+ context.attributes,
+ "user_defined_eager_row_processor",
+ user_defined_adapter,
+ )
elif adapter:
user_defined_adapter = adapter
path.set(
- context.attributes, "user_defined_eager_row_processor",
- user_defined_adapter)
+ context.attributes,
+ "user_defined_eager_row_processor",
+ user_defined_adapter,
+ )
add_to_collection = context.primary_columns
return user_defined_adapter, adapter, add_to_collection
@@ -1450,7 +1578,7 @@ class JoinedLoader(AbstractRelationshipLoader):
# we need one unique AliasedClass per query per appearance of our
# entity in the query.
- key = ('joinedloader_ac', self)
+ key = ("joinedloader_ac", self)
if key not in context.attributes:
context.attributes[key] = idx = 0
else:
@@ -1458,9 +1586,8 @@ class JoinedLoader(AbstractRelationshipLoader):
if idx >= len(self._aliased_class_pool):
to_adapt = orm_util.AliasedClass(
- self.mapper,
- flat=True,
- use_mapper_path=True)
+ self.mapper, flat=True, use_mapper_path=True
+ )
# load up the .columns collection on the Alias() before
# the object becomes shared among threads. this prevents
# races for column identities.
@@ -1471,13 +1598,18 @@ class JoinedLoader(AbstractRelationshipLoader):
return self._aliased_class_pool[idx]
def _generate_row_adapter(
- self,
- context, entity, path, loadopt, adapter,
- column_collection, parentmapper, chained_from_outerjoin):
+ self,
+ context,
+ entity,
+ path,
+ loadopt,
+ adapter,
+ column_collection,
+ parentmapper,
+ chained_from_outerjoin,
+ ):
with_poly_info = path.get(
- context.attributes,
- "path_with_polymorphic",
- None
+ context.attributes, "path_with_polymorphic", None
)
if with_poly_info:
to_adapt = with_poly_info.entity
@@ -1489,8 +1621,9 @@ class JoinedLoader(AbstractRelationshipLoader):
orm_util.ORMAdapter,
to_adapt,
equivalents=self.mapper._equivalent_columns,
- adapt_required=True, allow_label_resolve=False,
- anonymize_labels=True
+ adapt_required=True,
+ allow_label_resolve=False,
+ anonymize_labels=True,
)
assert clauses.aliased_class is not None
@@ -1499,8 +1632,7 @@ class JoinedLoader(AbstractRelationshipLoader):
context.multi_row_eager_loaders = True
innerjoin = (
- loadopt.local_opts.get(
- 'innerjoin', self.parent_property.innerjoin)
+ loadopt.local_opts.get("innerjoin", self.parent_property.innerjoin)
if loadopt is not None
else self.parent_property.innerjoin
)
@@ -1512,9 +1644,15 @@ class JoinedLoader(AbstractRelationshipLoader):
context.create_eager_joins.append(
(
- self._create_eager_join, context,
- entity, path, adapter,
- parentmapper, clauses, innerjoin, chained_from_outerjoin
+ self._create_eager_join,
+ context,
+ entity,
+ path,
+ adapter,
+ parentmapper,
+ clauses,
+ innerjoin,
+ chained_from_outerjoin,
)
)
@@ -1524,9 +1662,16 @@ class JoinedLoader(AbstractRelationshipLoader):
return clauses, adapter, add_to_collection, chained_from_outerjoin
def _create_eager_join(
- self, context, entity,
- path, adapter, parentmapper,
- clauses, innerjoin, chained_from_outerjoin):
+ self,
+ context,
+ entity,
+ path,
+ adapter,
+ parentmapper,
+ clauses,
+ innerjoin,
+ chained_from_outerjoin,
+ ):
if parentmapper is None:
localparent = entity.mapper
@@ -1536,16 +1681,21 @@ class JoinedLoader(AbstractRelationshipLoader):
# whether or not the Query will wrap the selectable in a subquery,
# and then attach eager load joins to that (i.e., in the case of
# LIMIT/OFFSET etc.)
- should_nest_selectable = context.multi_row_eager_loaders and \
- context.query._should_nest_selectable
+ should_nest_selectable = (
+ context.multi_row_eager_loaders
+ and context.query._should_nest_selectable
+ )
entity_key = None
- if entity not in context.eager_joins and \
- not should_nest_selectable and \
- context.from_clause:
+ if (
+ entity not in context.eager_joins
+ and not should_nest_selectable
+ and context.from_clause
+ ):
indexes = sql_util.find_left_clause_that_matches_given(
- context.from_clause, entity.selectable)
+ context.from_clause, entity.selectable
+ )
if len(indexes) > 1:
# for the eager load case, I can't reproduce this right
@@ -1553,7 +1703,8 @@ class JoinedLoader(AbstractRelationshipLoader):
raise sa_exc.InvalidRequestError(
"Can't identify which entity in which to joined eager "
"load from. Please use an exact match when specifying "
- "the join path.")
+ "the join path."
+ )
if indexes:
clause = context.from_clause[indexes[0]]
@@ -1569,29 +1720,27 @@ class JoinedLoader(AbstractRelationshipLoader):
towrap = context.eager_joins.setdefault(entity_key, default_towrap)
if adapter:
- if getattr(adapter, 'aliased_class', None):
+ if getattr(adapter, "aliased_class", None):
# joining from an adapted entity. The adapted entity
# might be a "with_polymorphic", so resolve that to our
# specific mapper's entity before looking for our attribute
# name on it.
- efm = inspect(adapter.aliased_class).\
- _entity_for_mapper(
- localparent
- if localparent.isa(self.parent) else self.parent)
+ efm = inspect(adapter.aliased_class)._entity_for_mapper(
+ localparent
+ if localparent.isa(self.parent)
+ else self.parent
+ )
# look for our attribute on the adapted entity, else fall back
# to our straight property
- onclause = getattr(
- efm.entity, self.key,
- self.parent_property)
+ onclause = getattr(efm.entity, self.key, self.parent_property)
else:
onclause = getattr(
orm_util.AliasedClass(
- self.parent,
- adapter.selectable,
- use_mapper_path=True
+ self.parent, adapter.selectable, use_mapper_path=True
),
- self.key, self.parent_property
+ self.key,
+ self.parent_property,
)
else:
@@ -1600,9 +1749,10 @@ class JoinedLoader(AbstractRelationshipLoader):
assert clauses.aliased_class is not None
attach_on_outside = (
- not chained_from_outerjoin or
- not innerjoin or innerjoin == 'unnested' or
- entity.entity_zero.represents_outer_join
+ not chained_from_outerjoin
+ or not innerjoin
+ or innerjoin == "unnested"
+ or entity.entity_zero.represents_outer_join
)
if attach_on_outside:
@@ -1611,16 +1761,17 @@ class JoinedLoader(AbstractRelationshipLoader):
towrap,
clauses.aliased_class,
onclause,
- isouter=not innerjoin or
- entity.entity_zero.represents_outer_join or
- (
- chained_from_outerjoin and isinstance(towrap, sql.Join)
- ), _left_memo=self.parent, _right_memo=self.mapper
+ isouter=not innerjoin
+ or entity.entity_zero.represents_outer_join
+ or (chained_from_outerjoin and isinstance(towrap, sql.Join)),
+ _left_memo=self.parent,
+ _right_memo=self.mapper,
)
else:
# all other cases are innerjoin=='nested' approach
eagerjoin = self._splice_nested_inner_join(
- path, towrap, clauses, onclause)
+ path, towrap, clauses, onclause
+ )
context.eager_joins[entity_key] = eagerjoin
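# The innerjoin / attach_on_outside branching above is driven by the
# joinedload option of the same name; hypothetical per-query usage
# (Parent/children are invented names):
from sqlalchemy.orm import joinedload

# innerjoin=True requests a plain INNER JOIN; innerjoin="unnested" keeps
# a chained join from being nested inside an earlier OUTER JOIN, matching
# the innerjoin == "unnested" test above.
inner_opt = joinedload("children", innerjoin=True)
unnested_opt = joinedload("children", innerjoin="unnested")
# hypothetical query: session.query(Parent).options(inner_opt)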
@@ -1636,22 +1787,21 @@ class JoinedLoader(AbstractRelationshipLoader):
# This has the effect
# of "undefering" those columns.
for col in sql_util._find_columns(
- self.parent_property.primaryjoin):
+ self.parent_property.primaryjoin
+ ):
if localparent.mapped_table.c.contains_column(col):
if adapter:
col = adapter.columns[col]
context.primary_columns.append(col)
if self.parent_property.order_by:
- context.eager_order_by += eagerjoin._target_adapter.\
- copy_and_process(
- util.to_list(
- self.parent_property.order_by
- )
- )
+ context.eager_order_by += eagerjoin._target_adapter.copy_and_process(
+ util.to_list(self.parent_property.order_by)
+ )
def _splice_nested_inner_join(
- self, path, join_obj, clauses, onclause, splicing=False):
+ self, path, join_obj, clauses, onclause, splicing=False
+ ):
if splicing is False:
# first call is always handed a join object
@@ -1664,28 +1814,31 @@ class JoinedLoader(AbstractRelationshipLoader):
elif not isinstance(join_obj, orm_util._ORMJoin):
if path[-2] is splicing:
return orm_util._ORMJoin(
- join_obj, clauses.aliased_class,
- onclause, isouter=False,
+ join_obj,
+ clauses.aliased_class,
+ onclause,
+ isouter=False,
_left_memo=splicing,
- _right_memo=path[-1].mapper
+ _right_memo=path[-1].mapper,
)
else:
# only here if splicing == True
return None
target_join = self._splice_nested_inner_join(
- path, join_obj.right, clauses,
- onclause, join_obj._right_memo)
+ path, join_obj.right, clauses, onclause, join_obj._right_memo
+ )
if target_join is None:
right_splice = False
target_join = self._splice_nested_inner_join(
- path, join_obj.left, clauses,
- onclause, join_obj._left_memo)
+ path, join_obj.left, clauses, onclause, join_obj._left_memo
+ )
if target_join is None:
# should only return None when recursively called,
# e.g. splicing==True
- assert splicing is not False, \
- "assertion failed attempting to produce joined eager loads"
+ assert (
+ splicing is not False
+ ), "assertion failed attempting to produce joined eager loads"
return None
else:
right_splice = True
@@ -1698,21 +1851,30 @@ class JoinedLoader(AbstractRelationshipLoader):
eagerjoin = join_obj._splice_into_center(target_join)
else:
eagerjoin = orm_util._ORMJoin(
- join_obj.left, target_join,
- join_obj.onclause, isouter=join_obj.isouter,
- _left_memo=join_obj._left_memo)
+ join_obj.left,
+ target_join,
+ join_obj.onclause,
+ isouter=join_obj.isouter,
+ _left_memo=join_obj._left_memo,
+ )
else:
eagerjoin = orm_util._ORMJoin(
- target_join, join_obj.right,
- join_obj.onclause, isouter=join_obj.isouter,
- _right_memo=join_obj._right_memo)
+ target_join,
+ join_obj.right,
+ join_obj.onclause,
+ isouter=join_obj.isouter,
+ _right_memo=join_obj._right_memo,
+ )
eagerjoin._target_adapter = target_join._target_adapter
return eagerjoin
def _create_eager_adapter(self, context, result, adapter, path, loadopt):
- user_defined_adapter = self._init_user_defined_eager_proc(
- loadopt, context) if loadopt else False
+ user_defined_adapter = (
+ self._init_user_defined_eager_proc(loadopt, context)
+ if loadopt
+ else False
+ )
if user_defined_adapter is not False:
decorator = user_defined_adapter
@@ -1736,21 +1898,19 @@ class JoinedLoader(AbstractRelationshipLoader):
return False
def create_row_processor(
- self, context, path, loadopt, mapper,
- result, adapter, populators):
+ self, context, path, loadopt, mapper, result, adapter, populators
+ ):
if not self.parent.class_manager[self.key].impl.supports_population:
raise sa_exc.InvalidRequestError(
"'%s' does not support object "
- "population - eager loading cannot be applied." %
- self
+ "population - eager loading cannot be applied." % self
)
our_path = path[self.parent_property]
eager_adapter = self._create_eager_adapter(
- context,
- result,
- adapter, our_path, loadopt)
+ context, result, adapter, our_path, loadopt
+ )
if eager_adapter is not False:
key = self.key
@@ -1760,25 +1920,28 @@ class JoinedLoader(AbstractRelationshipLoader):
context,
result,
our_path[self.mapper],
- eager_adapter)
+ eager_adapter,
+ )
if not self.uselist:
self._create_scalar_loader(context, key, _instance, populators)
else:
self._create_collection_loader(
- context, key, _instance, populators)
+ context, key, _instance, populators
+ )
else:
- self.parent_property._get_strategy((("lazy", "select"),)).\
- create_row_processor(
- context, path, loadopt,
- mapper, result, adapter, populators)
+ self.parent_property._get_strategy(
+ (("lazy", "select"),)
+ ).create_row_processor(
+ context, path, loadopt, mapper, result, adapter, populators
+ )
def _create_collection_loader(self, context, key, _instance, populators):
def load_collection_from_joined_new_row(state, dict_, row):
- collection = attributes.init_state_collection(
- state, dict_, key)
- result_list = util.UniqueAppender(collection,
- 'append_without_event')
+ collection = attributes.init_state_collection(state, dict_, key)
+ result_list = util.UniqueAppender(
+ collection, "append_without_event"
+ )
context.attributes[(state, key)] = result_list
inst = _instance(row)
if inst is not None:
@@ -1793,10 +1956,11 @@ class JoinedLoader(AbstractRelationshipLoader):
# is used; the same instance may be present in two
# distinct sets of result columns
collection = attributes.init_state_collection(
- state, dict_, key)
+ state, dict_, key
+ )
result_list = util.UniqueAppender(
- collection,
- 'append_without_event')
+ collection, "append_without_event"
+ )
context.attributes[(state, key)] = result_list
inst = _instance(row)
if inst is not None:
@@ -1805,12 +1969,16 @@ class JoinedLoader(AbstractRelationshipLoader):
def load_collection_from_joined_exec(state, dict_, row):
_instance(row)
- populators["new"].append((self.key, load_collection_from_joined_new_row))
+ populators["new"].append(
+ (self.key, load_collection_from_joined_new_row)
+ )
populators["existing"].append(
- (self.key, load_collection_from_joined_existing_row))
+ (self.key, load_collection_from_joined_existing_row)
+ )
if context.invoke_all_eagers:
populators["eager"].append(
- (self.key, load_collection_from_joined_exec))
+ (self.key, load_collection_from_joined_exec)
+ )
def _create_scalar_loader(self, context, key, _instance, populators):
def load_scalar_from_joined_new_row(state, dict_, row):
@@ -1829,7 +1997,8 @@ class JoinedLoader(AbstractRelationshipLoader):
util.warn(
"Multiple rows returned with "
"uselist=False for eagerly-loaded attribute '%s' "
- % self)
+ % self
+ )
else:
# this case is when one row has multiple loads of the
# same entity (e.g. via aliasing), one has an attribute
@@ -1841,17 +2010,25 @@ class JoinedLoader(AbstractRelationshipLoader):
populators["new"].append((self.key, load_scalar_from_joined_new_row))
populators["existing"].append(
- (self.key, load_scalar_from_joined_existing_row))
+ (self.key, load_scalar_from_joined_existing_row)
+ )
if context.invoke_all_eagers:
- populators["eager"].append((self.key, load_scalar_from_joined_exec))
+ populators["eager"].append(
+ (self.key, load_scalar_from_joined_exec)
+ )
@log.class_logger
@properties.RelationshipProperty.strategy_for(lazy="selectin")
class SelectInLoader(AbstractRelationshipLoader, util.MemoizedSlots):
__slots__ = (
- 'join_depth', 'omit_join', '_parent_alias', '_in_expr',
- '_pk_cols', '_zero_idx', '_bakery'
+ "join_depth",
+ "omit_join",
+ "_parent_alias",
+ "_in_expr",
+ "_pk_cols",
+ "_zero_idx",
+ "_bakery",
)
_chunksize = 500
@@ -1864,11 +2041,12 @@ class SelectInLoader(AbstractRelationshipLoader, util.MemoizedSlots):
self.omit_join = self.parent_property.omit_join
else:
lazyloader = self.parent_property._get_strategy(
- (("lazy", "select"),))
+ (("lazy", "select"),)
+ )
self.omit_join = self.parent._get_clause[0].compare(
lazyloader._rev_lazywhere,
use_proxies=True,
- equivalents=self.parent._equivalent_columns
+ equivalents=self.parent._equivalent_columns,
)
if self.omit_join:
self._init_for_omit_join()
@@ -1886,8 +2064,8 @@ class SelectInLoader(AbstractRelationshipLoader, util.MemoizedSlots):
)
self._pk_cols = fk_cols = [
- pk_to_fk[col]
- for col in self.parent.primary_key if col in pk_to_fk]
+ pk_to_fk[col] for col in self.parent.primary_key if col in pk_to_fk
+ ]
if len(fk_cols) > 1:
self._in_expr = sql.tuple_(*fk_cols)
self._zero_idx = False
@@ -1899,7 +2077,8 @@ class SelectInLoader(AbstractRelationshipLoader, util.MemoizedSlots):
self._parent_alias = aliased(self.parent.class_)
pa_insp = inspect(self._parent_alias)
self._pk_cols = pk_cols = [
- pa_insp._adapt_element(col) for col in self.parent.primary_key]
+ pa_insp._adapt_element(col) for col in self.parent.primary_key
+ ]
if len(pk_cols) > 1:
self._in_expr = sql.tuple_(*pk_cols)
self._zero_idx = False
@@ -1908,26 +2087,26 @@ class SelectInLoader(AbstractRelationshipLoader, util.MemoizedSlots):
self._zero_idx = True
def init_class_attribute(self, mapper):
- self.parent_property.\
- _get_strategy((("lazy", "select"),)).\
- init_class_attribute(mapper)
+ self.parent_property._get_strategy(
+ (("lazy", "select"),)
+ ).init_class_attribute(mapper)
@util.dependencies("sqlalchemy.ext.baked")
def _memoized_attr__bakery(self, baked):
return baked.bakery(size=50)
def create_row_processor(
- self, context, path, loadopt, mapper,
- result, adapter, populators):
+ self, context, path, loadopt, mapper, result, adapter, populators
+ ):
if not self.parent.class_manager[self.key].impl.supports_population:
raise sa_exc.InvalidRequestError(
"'%s' does not support object "
- "population - eager loading cannot be applied." %
- self
+ "population - eager loading cannot be applied." % self
)
selectin_path = (
- context.query._current_path or orm_util.PathRegistry.root) + path
+ context.query._current_path or orm_util.PathRegistry.root
+ ) + path
if not orm_util._entity_isa(path[-1], self.parent):
return
@@ -1941,8 +2120,8 @@ class SelectInLoader(AbstractRelationshipLoader, util.MemoizedSlots):
# build up a path indicating the path from the leftmost
# entity to the thing we're subquery loading.
with_poly_info = path_w_prop.get(
- context.attributes,
- "path_with_polymorphic", None)
+ context.attributes, "path_with_polymorphic", None
+ )
if with_poly_info is not None:
effective_entity = with_poly_info.entity
@@ -1957,19 +2136,24 @@ class SelectInLoader(AbstractRelationshipLoader, util.MemoizedSlots):
return
loading.PostLoad.callable_for_path(
- context, selectin_path, self.parent, self.key,
- self._load_for_path, effective_entity)
+ context,
+ selectin_path,
+ self.parent,
+ self.key,
+ self._load_for_path,
+ effective_entity,
+ )
@util.dependencies("sqlalchemy.ext.baked")
def _load_for_path(
- self, baked, context, path, states, load_only, effective_entity):
+ self, baked, context, path, states, load_only, effective_entity
+ ):
if load_only and self.key not in load_only:
return
our_states = [
- (state.key[1], state, overwrite)
- for state, overwrite in states
+ (state.key[1], state, overwrite) for state, overwrite in states
]
pk_cols = self._pk_cols
@@ -1984,17 +2168,15 @@ class SelectInLoader(AbstractRelationshipLoader, util.MemoizedSlots):
# parent entity and do not need adaption.
insp = inspect(effective_entity)
if insp.is_aliased_class:
- pk_cols = [
- insp._adapt_element(col)
- for col in pk_cols
- ]
+ pk_cols = [insp._adapt_element(col) for col in pk_cols]
in_expr = insp._adapt_element(in_expr)
q = self._bakery(
lambda session: session.query(
- query.Bundle("pk", *pk_cols), effective_entity,
- ), self
+ query.Bundle("pk", *pk_cols), effective_entity
+ ),
+ self,
)
if self.omit_join:
@@ -2012,60 +2194,53 @@ class SelectInLoader(AbstractRelationshipLoader, util.MemoizedSlots):
q.add_criteria(
lambda q: q.select_from(pa).join(
getattr(pa, self.parent_property.key).of_type(
- effective_entity)
+ effective_entity
+ )
)
)
q.add_criteria(
lambda q: q.filter(
- in_expr.in_(
- sql.bindparam("primary_keys", expanding=True))
- ).order_by(*pk_cols))
+ in_expr.in_(sql.bindparam("primary_keys", expanding=True))
+ ).order_by(*pk_cols)
+ )
orig_query = context.query
q._add_lazyload_options(
- orig_query._with_options,
- path[self.parent_property]
+ orig_query._with_options, path[self.parent_property]
)
if orig_query._populate_existing:
- q.add_criteria(
- lambda q: q.populate_existing()
- )
+ q.add_criteria(lambda q: q.populate_existing())
if self.parent_property.order_by:
if self.omit_join:
eager_order_by = self.parent_property.order_by
if insp.is_aliased_class:
eager_order_by = [
- insp._adapt_element(elem) for elem in
- eager_order_by
+ insp._adapt_element(elem) for elem in eager_order_by
]
- q.add_criteria(
- lambda q: q.order_by(*eager_order_by)
- )
+ q.add_criteria(lambda q: q.order_by(*eager_order_by))
else:
+
def _setup_outermost_orderby(q):
# imitate the same method that subquery eager loading uses,
# looking for the adapted "secondary" table
eagerjoin = q._from_obj[0]
- eager_order_by = \
- eagerjoin._target_adapter.\
- copy_and_process(
- util.to_list(self.parent_property.order_by)
- )
+ eager_order_by = eagerjoin._target_adapter.copy_and_process(
+ util.to_list(self.parent_property.order_by)
+ )
return q.order_by(*eager_order_by)
- q.add_criteria(
- _setup_outermost_orderby
- )
+
+ q.add_criteria(_setup_outermost_orderby)
uselist = self.uselist
_empty_result = () if uselist else None
while our_states:
- chunk = our_states[0:self._chunksize]
- our_states = our_states[self._chunksize:]
+ chunk = our_states[0 : self._chunksize]
+ our_states = our_states[self._chunksize :]
data = {
k: [vv[1] for vv in v]
@@ -2073,9 +2248,10 @@ class SelectInLoader(AbstractRelationshipLoader, util.MemoizedSlots):
q(context.session).params(
primary_keys=[
key[0] if self._zero_idx else key
- for key, state, overwrite in chunk]
+ for key, state, overwrite in chunk
+ ]
),
- lambda x: x[0]
+ lambda x: x[0],
)
}
@@ -2091,13 +2267,15 @@ class SelectInLoader(AbstractRelationshipLoader, util.MemoizedSlots):
util.warn(
"Multiple rows returned with "
"uselist=False for eagerly-loaded "
- "attribute '%s' "
- % self)
+ "attribute '%s' " % self
+ )
state.get_impl(self.key).set_committed_value(
- state, state.dict, collection[0])
+ state, state.dict, collection[0]
+ )
else:
state.get_impl(self.key).set_committed_value(
- state, state.dict, collection)
+ state, state.dict, collection
+ )
def single_parent_validator(desc, prop):
@@ -2108,8 +2286,8 @@ def single_parent_validator(desc, prop):
raise sa_exc.InvalidRequestError(
"Instance %s is already associated with an instance "
"of %s via its %s attribute, and is only allowed a "
- "single parent." %
- (orm_util.instance_str(value), state.class_, prop)
+ "single parent."
+ % (orm_util.instance_str(value), state.class_, prop)
)
return value
@@ -2120,8 +2298,6 @@ def single_parent_validator(desc, prop):
return _do_check(state, value, oldvalue, initiator)
event.listen(
- desc, 'append', append, raw=True, retval=True,
- active_history=True)
- event.listen(
- desc, 'set', set_, raw=True, retval=True,
- active_history=True)
+ desc, "append", append, raw=True, retval=True, active_history=True
+ )
+ event.listen(desc, "set", set_, raw=True, retval=True, active_history=True)
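
For context: the SelectInLoader hunks above batch the eager-load query by parent primary key, _chunksize (500) keys at a time, bound to an "expanding" IN parameter and ordered by the key columns. A minimal Core-only sketch of that batching, against a hypothetical child table rather than anything in the patch:

# Editor's sketch (hypothetical table): mirrors the chunking loop and the
# expanding bindparam used by SelectInLoader._load_for_path above.
from sqlalchemy import Column, Integer, MetaData, Table, bindparam, select

metadata = MetaData()
child = Table(
    "child",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("parent_id", Integer),  # stand-in foreign key to the parent
)

CHUNKSIZE = 500  # mirrors SelectInLoader._chunksize


def chunked_selectin_queries(parent_pks):
    """Yield one (statement, params) pair per chunk of parent keys."""
    pks = list(parent_pks)
    while pks:
        chunk, pks = pks[:CHUNKSIZE], pks[CHUNKSIZE:]
        stmt = (
            select([child])
            .where(
                child.c.parent_id.in_(
                    bindparam("primary_keys", expanding=True)
                )
            )
            .order_by(child.c.parent_id)
        )
        yield stmt, {"primary_keys": chunk}


for stmt, params in chunked_selectin_queries(range(1, 1201)):
    print(len(params["primary_keys"]))  # 500, 500, 200
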
diff --git a/lib/sqlalchemy/orm/strategy_options.py b/lib/sqlalchemy/orm/strategy_options.py
index f0d209110..b2f6bcb11 100644
--- a/lib/sqlalchemy/orm/strategy_options.py
+++ b/lib/sqlalchemy/orm/strategy_options.py
@@ -13,11 +13,19 @@ from .attributes import QueryableAttribute
from .. import util
from ..sql.base import _generative, Generative
from .. import exc as sa_exc, inspect
-from .base import _is_aliased_class, _class_to_mapper, _is_mapped_class, \
- InspectionAttr
+from .base import (
+ _is_aliased_class,
+ _class_to_mapper,
+ _is_mapped_class,
+ InspectionAttr,
+)
from . import util as orm_util
-from .path_registry import PathRegistry, TokenRegistry, \
- _WILDCARD_TOKEN, _DEFAULT_TOKEN
+from .path_registry import (
+ PathRegistry,
+ TokenRegistry,
+ _WILDCARD_TOKEN,
+ _DEFAULT_TOKEN,
+)
class Load(Generative, MapperOption):
@@ -94,12 +102,14 @@ class Load(Generative, MapperOption):
if (
# means loader_path and path are unrelated,
# this does not need to be part of a cache key
- chopped is None
+ chopped
+ is None
) or (
# means no additional path with loader_path + path
# and the endpoint isn't using of_type so isn't modified
# into an alias or other unsafe entity
- not chopped and not obj._of_type
+ not chopped
+ and not obj._of_type
):
continue
@@ -124,12 +134,18 @@ class Load(Generative, MapperOption):
serialized.append(
(
- tuple(serialized_path) +
- (obj.strategy or ()) +
- (tuple([
- (key, obj.local_opts[key])
- for key in sorted(obj.local_opts)
- ]) if obj.local_opts else ())
+ tuple(serialized_path)
+ + (obj.strategy or ())
+ + (
+ tuple(
+ [
+ (key, obj.local_opts[key])
+ for key in sorted(obj.local_opts)
+ ]
+ )
+ if obj.local_opts
+ else ()
+ )
)
)
if not serialized:
@@ -170,12 +186,13 @@ class Load(Generative, MapperOption):
if raiseerr and not path.has_entity:
if isinstance(path, TokenRegistry):
raise sa_exc.ArgumentError(
- "Wildcard token cannot be followed by another entity")
+ "Wildcard token cannot be followed by another entity"
+ )
else:
raise sa_exc.ArgumentError(
"Attribute '%s' of entity '%s' does not "
- "refer to a mapped entity" %
- (path.prop.key, path.parent.entity)
+ "refer to a mapped entity"
+ % (path.prop.key, path.parent.entity)
)
if isinstance(attr, util.string_types):
@@ -201,8 +218,7 @@ class Load(Generative, MapperOption):
if raiseerr:
raise sa_exc.ArgumentError(
"Can't find property named '%s' on the "
- "mapped entity %s in this Query. " % (
- attr, ent)
+ "mapped entity %s in this Query. " % (attr, ent)
)
else:
return None
@@ -215,7 +231,8 @@ class Load(Generative, MapperOption):
if raiseerr:
raise sa_exc.ArgumentError(
"Attribute '%s' does not "
- "link from element '%s'" % (attr, path.entity))
+ "link from element '%s'" % (attr, path.entity)
+ )
else:
return None
else:
@@ -225,22 +242,26 @@ class Load(Generative, MapperOption):
if raiseerr:
raise sa_exc.ArgumentError(
"Attribute '%s' does not "
- "link from element '%s'" % (attr, path.entity))
+ "link from element '%s'" % (attr, path.entity)
+ )
else:
return None
- if getattr(attr, '_of_type', None):
+ if getattr(attr, "_of_type", None):
ac = attr._of_type
ext_info = of_type_info = inspect(ac)
existing = path.entity_path[prop].get(
- self.context, "path_with_polymorphic")
+ self.context, "path_with_polymorphic"
+ )
if not ext_info.is_aliased_class:
ac = orm_util.with_polymorphic(
ext_info.mapper.base_mapper,
- ext_info.mapper, aliased=True,
+ ext_info.mapper,
+ aliased=True,
_use_mapper_path=True,
- _existing_alias=existing)
+ _existing_alias=existing,
+ )
ext_info = inspect(ac)
elif not ext_info.with_polymorphic_mappers:
ext_info = orm_util.AliasedInsp(
@@ -253,11 +274,12 @@ class Load(Generative, MapperOption):
ext_info._base_alias,
ext_info._use_mapper_path,
ext_info._adapt_on_names,
- ext_info.represents_outer_join
+ ext_info.represents_outer_join,
)
path.entity_path[prop].set(
- self.context, "path_with_polymorphic", ext_info)
+ self.context, "path_with_polymorphic", ext_info
+ )
# the path here will go into the context dictionary and
# needs to match up to how the class graph is traversed.
@@ -280,7 +302,7 @@ class Load(Generative, MapperOption):
return path
def __str__(self):
- return "Load(strategy=%r)" % (self.strategy, )
+ return "Load(strategy=%r)" % (self.strategy,)
def _coerce_strat(self, strategy):
if strategy is not None:
@@ -289,7 +311,8 @@ class Load(Generative, MapperOption):
@_generative
def set_relationship_strategy(
- self, attr, strategy, propagate_to_loaders=True):
+ self, attr, strategy, propagate_to_loaders=True
+ ):
strategy = self._coerce_strat(strategy)
self.is_class_strategy = False
@@ -365,12 +388,18 @@ class Load(Generative, MapperOption):
if effective_path.is_token:
for path in effective_path.generate_for_superclasses():
self._set_for_path(
- self.context, path, replace=True,
- merge_opts=self.is_opts_only)
+ self.context,
+ path,
+ replace=True,
+ merge_opts=self.is_opts_only,
+ )
else:
self._set_for_path(
- self.context, effective_path, replace=True,
- merge_opts=self.is_opts_only)
+ self.context,
+ effective_path,
+ replace=True,
+ merge_opts=self.is_opts_only,
+ )
def __getstate__(self):
d = self.__dict__.copy()
@@ -389,21 +418,26 @@ class Load(Generative, MapperOption):
# TODO: this is approximated from the _UnboundLoad
# version and probably has issues, not fully covered.
- if i == 0 and c_token.endswith(':' + _DEFAULT_TOKEN):
+ if i == 0 and c_token.endswith(":" + _DEFAULT_TOKEN):
return to_chop
- elif c_token != 'relationship:%s' % (_WILDCARD_TOKEN,) and \
- c_token != p_token.key:
+ elif (
+ c_token != "relationship:%s" % (_WILDCARD_TOKEN,)
+ and c_token != p_token.key
+ ):
return None
if c_token is p_token:
continue
- elif isinstance(c_token, InspectionAttr) and \
- c_token.is_mapper and p_token.is_mapper and \
- c_token.isa(p_token):
+ elif (
+ isinstance(c_token, InspectionAttr)
+ and c_token.is_mapper
+ and p_token.is_mapper
+ and c_token.isa(p_token)
+ ):
continue
else:
return None
- return to_chop[i + 1:]
+ return to_chop[i + 1 :]
class _UnboundLoad(Load):
@@ -431,9 +465,7 @@ class _UnboundLoad(Load):
if local_elem is not val_elem:
break
else:
- opt = val._bind_loader(
- [path.path[0]],
- None, None, False)
+ opt = val._bind_loader([path.path[0]], None, None, False)
if opt:
c_key = opt._generate_cache_key(path)
if c_key is False:
@@ -449,26 +481,29 @@ class _UnboundLoad(Load):
self._to_bind.append(self)
def _generate_path(self, path, attr, wildcard_key):
- if wildcard_key and isinstance(attr, util.string_types) and \
- attr in (_WILDCARD_TOKEN, _DEFAULT_TOKEN):
+ if (
+ wildcard_key
+ and isinstance(attr, util.string_types)
+ and attr in (_WILDCARD_TOKEN, _DEFAULT_TOKEN)
+ ):
if attr == _DEFAULT_TOKEN:
self.propagate_to_loaders = False
attr = "%s:%s" % (wildcard_key, attr)
if path and _is_mapped_class(path[-1]) and not self.is_class_strategy:
path = path[0:-1]
if attr:
- path = path + (attr, )
+ path = path + (attr,)
self.path = path
return path
def __getstate__(self):
d = self.__dict__.copy()
- d['path'] = self._serialize_path(self.path, filter_aliased_class=True)
+ d["path"] = self._serialize_path(self.path, filter_aliased_class=True)
return d
def __setstate__(self, state):
ret = []
- for key in state['path']:
+ for key in state["path"]:
if isinstance(key, tuple):
if len(key) == 2:
# support legacy
@@ -482,17 +517,20 @@ class _UnboundLoad(Load):
ret.append(prop)
else:
ret.append(key)
- state['path'] = tuple(ret)
+ state["path"] = tuple(ret)
self.__dict__ = state
def _process(self, query, raiseerr):
- dedupes = query._attributes['_unbound_load_dedupes']
+ dedupes = query._attributes["_unbound_load_dedupes"]
for val in self._to_bind:
if val not in dedupes:
dedupes.add(val)
val._bind_loader(
[ent.entity_zero for ent in query._mapper_entities],
- query._current_path, query._attributes, raiseerr)
+ query._current_path,
+ query._attributes,
+ raiseerr,
+ )
@classmethod
def _from_keys(cls, meth, keys, chained, kw):
@@ -502,13 +540,14 @@ class _UnboundLoad(Load):
if isinstance(key, util.string_types):
# coerce fooload('*') into "default loader strategy"
if key == _WILDCARD_TOKEN:
- return (_DEFAULT_TOKEN, )
+ return (_DEFAULT_TOKEN,)
# coerce fooload(".*") into "wildcard on default entity"
elif key.startswith("." + _WILDCARD_TOKEN):
key = key[1:]
return key.split(".")
else:
return (key,)
+
all_tokens = [token for key in keys for token in _split_key(key)]
for token in all_tokens[0:-1]:
@@ -526,21 +565,24 @@ class _UnboundLoad(Load):
def _chop_path(self, to_chop, path):
i = -1
for i, (c_token, (p_entity, p_prop)) in enumerate(
- zip(to_chop, path.pairs())):
+ zip(to_chop, path.pairs())
+ ):
if isinstance(c_token, util.string_types):
- if i == 0 and c_token.endswith(':' + _DEFAULT_TOKEN):
+ if i == 0 and c_token.endswith(":" + _DEFAULT_TOKEN):
return to_chop
- elif c_token != 'relationship:%s' % (
- _WILDCARD_TOKEN,) and c_token != p_prop.key:
+ elif (
+ c_token != "relationship:%s" % (_WILDCARD_TOKEN,)
+ and c_token != p_prop.key
+ ):
return None
elif isinstance(c_token, PropComparator):
- if c_token.property is not p_prop or \
- (
- c_token._parententity is not p_entity and (
- not c_token._parententity.is_mapper or
- not c_token._parententity.isa(p_entity)
- )
- ):
+ if c_token.property is not p_prop or (
+ c_token._parententity is not p_entity
+ and (
+ not c_token._parententity.is_mapper
+ or not c_token._parententity.isa(p_entity)
+ )
+ ):
return None
else:
i += 1
@@ -551,15 +593,16 @@ class _UnboundLoad(Load):
ret = []
for token in path:
if isinstance(token, QueryableAttribute):
- if filter_aliased_class and token._of_type and \
- inspect(token._of_type).is_aliased_class:
- ret.append(
- (token._parentmapper.class_,
- token.key, None))
+ if (
+ filter_aliased_class
+ and token._of_type
+ and inspect(token._of_type).is_aliased_class
+ ):
+ ret.append((token._parentmapper.class_, token.key, None))
else:
ret.append(
- (token._parentmapper.class_, token.key,
- token._of_type))
+ (token._parentmapper.class_, token.key, token._of_type)
+ )
elif isinstance(token, PropComparator):
ret.append((token._parentmapper.class_, token.key, None))
else:
@@ -605,7 +648,7 @@ class _UnboundLoad(Load):
start_path = self.path
if self.is_class_strategy and current_path:
- start_path += (entities[0], )
+ start_path += (entities[0],)
# _current_path implies we're in a
# secondary load with an existing path
@@ -621,23 +664,20 @@ class _UnboundLoad(Load):
token = start_path[0]
if isinstance(token, util.string_types):
- entity = self._find_entity_basestring(
- entities, token, raiseerr)
+ entity = self._find_entity_basestring(entities, token, raiseerr)
elif isinstance(token, PropComparator):
prop = token.property
entity = self._find_entity_prop_comparator(
- entities,
- prop.key,
- token._parententity,
- raiseerr)
+ entities, prop.key, token._parententity, raiseerr
+ )
elif self.is_class_strategy and _is_mapped_class(token):
entity = inspect(token)
if entity not in entities:
entity = None
else:
raise sa_exc.ArgumentError(
- "mapper option expects "
- "string key or list of attributes")
+ "mapper option expects " "string key or list of attributes"
+ )
if not entity:
return
@@ -663,7 +703,8 @@ class _UnboundLoad(Load):
if not loader.is_class_strategy:
for token in start_path:
if not loader._generate_path(
- loader.path, token, None, raiseerr):
+ loader.path, token, None, raiseerr
+ ):
return
loader.local_opts.update(self.local_opts)
@@ -680,14 +721,18 @@ class _UnboundLoad(Load):
if effective_path.is_token:
for path in effective_path.generate_for_superclasses():
loader._set_for_path(
- context, path,
+ context,
+ path,
replace=not self._is_chain_link,
- merge_opts=self.is_opts_only)
+ merge_opts=self.is_opts_only,
+ )
else:
loader._set_for_path(
- context, effective_path,
+ context,
+ effective_path,
replace=not self._is_chain_link,
- merge_opts=self.is_opts_only)
+ merge_opts=self.is_opts_only,
+ )
return loader
@@ -704,28 +749,27 @@ class _UnboundLoad(Load):
if not list(entities):
raise sa_exc.ArgumentError(
"Query has only expression-based entities - "
- "can't find property named '%s'."
- % (token, )
+ "can't find property named '%s'." % (token,)
)
else:
raise sa_exc.ArgumentError(
"Can't find property '%s' on any entity "
"specified in this Query. Note the full path "
"from root (%s) to target entity must be specified."
- % (token, ",".join(str(x) for
- x in entities))
+ % (token, ",".join(str(x) for x in entities))
)
else:
return None
def _find_entity_basestring(self, entities, token, raiseerr):
- if token.endswith(':' + _WILDCARD_TOKEN):
+ if token.endswith(":" + _WILDCARD_TOKEN):
if len(list(entities)) != 1:
if raiseerr:
raise sa_exc.ArgumentError(
"Wildcard loader can only be used with exactly "
"one entity. Use Load(ent) to specify "
- "specific entities.")
+ "specific entities."
+ )
elif token.endswith(_DEFAULT_TOKEN):
raiseerr = False
@@ -738,8 +782,7 @@ class _UnboundLoad(Load):
if raiseerr:
raise sa_exc.ArgumentError(
"Query has only expression-based entities - "
- "can't find property named '%s'."
- % (token, )
+ "can't find property named '%s'." % (token,)
)
else:
return None
@@ -766,7 +809,9 @@ class loader_option(object):
See :func:`.orm.%(name)s` for usage examples.
-""" % {"name": self.name}
+""" % {
+ "name": self.name
+ }
fn.__doc__ = fn_doc
return self
@@ -783,7 +828,9 @@ See :func:`.orm.%(name)s` for usage examples.
%(name)s("someattribute").%(name)s("anotherattribute")
)
-""" % {"name": self.name}
+""" % {
+ "name": self.name
+ }
return self
@@ -840,23 +887,22 @@ def contains_eager(loadopt, attr, alias=None):
info = inspect(alias)
alias = info.selectable
- elif getattr(attr, '_of_type', None):
+ elif getattr(attr, "_of_type", None):
ot = inspect(attr._of_type)
alias = ot.selectable
cloned = loadopt.set_relationship_strategy(
- attr,
- {"lazy": "joined"},
- propagate_to_loaders=False
+ attr, {"lazy": "joined"}, propagate_to_loaders=False
)
- cloned.local_opts['eager_from_alias'] = alias
+ cloned.local_opts["eager_from_alias"] = alias
return cloned
@contains_eager._add_unbound_fn
def contains_eager(*keys, **kw):
return _UnboundLoad()._from_keys(
- _UnboundLoad.contains_eager, keys, True, kw)
+ _UnboundLoad.contains_eager, keys, True, kw
+ )
@loader_option()
@@ -894,12 +940,11 @@ def load_only(loadopt, *attrs):
"""
cloned = loadopt.set_column_strategy(
- attrs,
- {"deferred": False, "instrument": True}
+ attrs, {"deferred": False, "instrument": True}
+ )
+ cloned.set_column_strategy(
+ "*", {"deferred": True, "instrument": True}, {"undefer_pks": True}
)
- cloned.set_column_strategy("*",
- {"deferred": True, "instrument": True},
- {"undefer_pks": True})
return cloned
@@ -996,20 +1041,18 @@ def joinedload(loadopt, attr, innerjoin=None):
"""
loader = loadopt.set_relationship_strategy(attr, {"lazy": "joined"})
if innerjoin is not None:
- loader.local_opts['innerjoin'] = innerjoin
+ loader.local_opts["innerjoin"] = innerjoin
return loader
@joinedload._add_unbound_fn
def joinedload(*keys, **kw):
- return _UnboundLoad._from_keys(
- _UnboundLoad.joinedload, keys, False, kw)
+ return _UnboundLoad._from_keys(_UnboundLoad.joinedload, keys, False, kw)
@joinedload._add_unbound_all_fn
def joinedload_all(*keys, **kw):
- return _UnboundLoad._from_keys(
- _UnboundLoad.joinedload, keys, True, kw)
+ return _UnboundLoad._from_keys(_UnboundLoad.joinedload, keys, True, kw)
@loader_option()
@@ -1152,8 +1195,7 @@ def immediateload(loadopt, attr):
@immediateload._add_unbound_fn
def immediateload(*keys):
- return _UnboundLoad._from_keys(
- _UnboundLoad.immediateload, keys, False, {})
+ return _UnboundLoad._from_keys(_UnboundLoad.immediateload, keys, False, {})
@loader_option()
@@ -1213,7 +1255,8 @@ def raiseload(loadopt, attr, sql_only=False):
"""
return loadopt.set_relationship_strategy(
- attr, {"lazy": "raise_on_sql" if sql_only else "raise"})
+ attr, {"lazy": "raise_on_sql" if sql_only else "raise"}
+ )
@raiseload._add_unbound_fn
@@ -1251,10 +1294,7 @@ def defaultload(loadopt, attr):
:ref:`deferred_loading_w_multiple`
"""
- return loadopt.set_relationship_strategy(
- attr,
- None
- )
+ return loadopt.set_relationship_strategy(attr, None)
@defaultload._add_unbound_fn
@@ -1315,15 +1355,15 @@ def defer(loadopt, key):
"""
return loadopt.set_column_strategy(
- (key, ),
- {"deferred": True, "instrument": True}
+ (key,), {"deferred": True, "instrument": True}
)
@defer._add_unbound_fn
def defer(key, *addl_attrs):
return _UnboundLoad._from_keys(
- _UnboundLoad.defer, (key, ) + addl_attrs, False, {})
+ _UnboundLoad.defer, (key,) + addl_attrs, False, {}
+ )
@loader_option()
@@ -1362,15 +1402,15 @@ def undefer(loadopt, key):
"""
return loadopt.set_column_strategy(
- (key, ),
- {"deferred": False, "instrument": True}
+ (key,), {"deferred": False, "instrument": True}
)
@undefer._add_unbound_fn
def undefer(key, *addl_attrs):
return _UnboundLoad._from_keys(
- _UnboundLoad.undefer, (key, ) + addl_attrs, False, {})
+ _UnboundLoad.undefer, (key,) + addl_attrs, False, {}
+ )
@loader_option()
@@ -1405,10 +1445,7 @@ def undefer_group(loadopt, name):
"""
return loadopt.set_column_strategy(
- "*",
- None,
- {"undefer_group_%s" % name: True},
- opts_only=True
+ "*", None, {"undefer_group_%s" % name: True}, opts_only=True
)
@@ -1448,21 +1485,18 @@ def with_expression(loadopt, key, expression):
"""
- expression = sql_expr._labeled(
- _orm_full_deannotate(expression))
+ expression = sql_expr._labeled(_orm_full_deannotate(expression))
return loadopt.set_column_strategy(
- (key, ),
- {"query_expression": True},
- opts={"expression": expression}
+ (key,), {"query_expression": True}, opts={"expression": expression}
)
@with_expression._add_unbound_fn
def with_expression(key, expression):
return _UnboundLoad._from_keys(
- _UnboundLoad.with_expression, (key, ),
- False, {"expression": expression})
+ _UnboundLoad.with_expression, (key,), False, {"expression": expression}
+ )
@loader_option()
@@ -1483,7 +1517,11 @@ def selectin_polymorphic(loadopt, classes):
"""
loadopt.set_class_strategy(
{"selectinload_polymorphic": True},
- opts={"entities": tuple(sorted((inspect(cls) for cls in classes), key=id))}
+ opts={
+ "entities": tuple(
+ sorted((inspect(cls) for cls in classes), key=id)
+ )
+ },
)
return loadopt
@@ -1492,8 +1530,6 @@ def selectin_polymorphic(loadopt, classes):
def selectin_polymorphic(base_cls, classes):
ul = _UnboundLoad()
ul.is_class_strategy = True
- ul.path = (inspect(base_cls), )
- ul.selectin_polymorphic(
- classes
- )
+ ul.path = (inspect(base_cls),)
+ ul.selectin_polymorphic(classes)
return ul
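
As an illustration: a self-contained sketch exercising a few of the loader options reworked in this file, using a hypothetical User/Address model on in-memory SQLite; joinedload(..., innerjoin=True) lands in local_opts["innerjoin"] exactly as the hunk above shows.

from sqlalchemy import Column, ForeignKey, Integer, String, create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import (
    defer,
    joinedload,
    load_only,
    relationship,
    sessionmaker,
)

Base = declarative_base()


class User(Base):
    __tablename__ = "user"
    id = Column(Integer, primary_key=True)
    name = Column(String(50))
    bio = Column(String(200))
    addresses = relationship("Address")


class Address(Base):
    __tablename__ = "address"
    id = Column(Integer, primary_key=True)
    user_id = Column(Integer, ForeignKey("user.id"))
    email = Column(String(50))


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
session = sessionmaker(bind=engine)()
session.add(User(name="ed", bio="...", addresses=[Address(email="e@x")]))
session.commit()

users = (
    session.query(User)
    .options(
        joinedload(User.addresses, innerjoin=True),  # LEFT -> INNER JOIN
        load_only("name"),  # defer all other columns except the PK
        defer("bio"),  # redundant next to load_only; shown for the API
    )
    .all()
)
print(users[0].name, users[0].addresses)
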
diff --git a/lib/sqlalchemy/orm/sync.py b/lib/sqlalchemy/orm/sync.py
index 08a66a8db..0cd488cbd 100644
--- a/lib/sqlalchemy/orm/sync.py
+++ b/lib/sqlalchemy/orm/sync.py
@@ -13,8 +13,15 @@ between instances based on join conditions.
from . import exc, util as orm_util, attributes
-def populate(source, source_mapper, dest, dest_mapper,
- synchronize_pairs, uowcommit, flag_cascaded_pks):
+def populate(
+ source,
+ source_mapper,
+ dest,
+ dest_mapper,
+ synchronize_pairs,
+ uowcommit,
+ flag_cascaded_pks,
+):
source_dict = source.dict
dest_dict = dest.dict
@@ -22,8 +29,9 @@ def populate(source, source_mapper, dest, dest_mapper,
try:
# inline of source_mapper._get_state_attr_by_column
prop = source_mapper._columntoproperty[l]
- value = source.manager[prop.key].impl.get(source, source_dict,
- attributes.PASSIVE_OFF)
+ value = source.manager[prop.key].impl.get(
+ source, source_dict, attributes.PASSIVE_OFF
+ )
except exc.UnmappedColumnError:
_raise_col_to_prop(False, source_mapper, l, dest_mapper, r)
@@ -39,14 +47,16 @@ def populate(source, source_mapper, dest, dest_mapper,
# how often this logic is invoked for memory/performance
# reasons, since we only need this info for a primary key
# destination.
- if flag_cascaded_pks and l.primary_key and \
- r.primary_key and \
- r.references(l):
+ if (
+ flag_cascaded_pks
+ and l.primary_key
+ and r.primary_key
+ and r.references(l)
+ ):
uowcommit.attributes[("pk_cascaded", dest, r)] = True
-def bulk_populate_inherit_keys(
- source_dict, source_mapper, synchronize_pairs):
+def bulk_populate_inherit_keys(source_dict, source_mapper, synchronize_pairs):
# a simplified version of populate() used by bulk insert mode
for l, r in synchronize_pairs:
try:
@@ -64,14 +74,15 @@ def bulk_populate_inherit_keys(
def clear(dest, dest_mapper, synchronize_pairs):
for l, r in synchronize_pairs:
- if r.primary_key and \
- dest_mapper._get_state_attr_by_column(
- dest, dest.dict, r) not in orm_util._none_set:
+ if (
+ r.primary_key
+ and dest_mapper._get_state_attr_by_column(dest, dest.dict, r)
+ not in orm_util._none_set
+ ):
raise AssertionError(
"Dependency rule tried to blank-out primary key "
- "column '%s' on instance '%s'" %
- (r, orm_util.state_str(dest))
+ "column '%s' on instance '%s'" % (r, orm_util.state_str(dest))
)
try:
dest_mapper._set_state_attr_by_column(dest, dest.dict, r, None)
@@ -83,9 +94,11 @@ def update(source, source_mapper, dest, old_prefix, synchronize_pairs):
for l, r in synchronize_pairs:
try:
oldvalue = source_mapper._get_committed_attr_by_column(
- source.obj(), l)
+ source.obj(), l
+ )
value = source_mapper._get_state_attr_by_column(
- source, source.dict, l, passive=attributes.PASSIVE_OFF)
+ source, source.dict, l, passive=attributes.PASSIVE_OFF
+ )
except exc.UnmappedColumnError:
_raise_col_to_prop(False, source_mapper, l, None, r)
dest[r.key] = value
@@ -96,7 +109,8 @@ def populate_dict(source, source_mapper, dict_, synchronize_pairs):
for l, r in synchronize_pairs:
try:
value = source_mapper._get_state_attr_by_column(
- source, source.dict, l, passive=attributes.PASSIVE_OFF)
+ source, source.dict, l, passive=attributes.PASSIVE_OFF
+ )
except exc.UnmappedColumnError:
_raise_col_to_prop(False, source_mapper, l, None, r)
@@ -114,27 +128,31 @@ def source_modified(uowcommit, source, source_mapper, synchronize_pairs):
except exc.UnmappedColumnError:
_raise_col_to_prop(False, source_mapper, l, None, r)
history = uowcommit.get_attribute_history(
- source, prop.key, attributes.PASSIVE_NO_INITIALIZE)
+ source, prop.key, attributes.PASSIVE_NO_INITIALIZE
+ )
if bool(history.deleted):
return True
else:
return False
-def _raise_col_to_prop(isdest, source_mapper, source_column,
- dest_mapper, dest_column):
+def _raise_col_to_prop(
+ isdest, source_mapper, source_column, dest_mapper, dest_column
+):
if isdest:
raise exc.UnmappedColumnError(
"Can't execute sync rule for "
"destination column '%s'; mapper '%s' does not map "
"this column. Try using an explicit `foreign_keys` "
"collection which does not include this column (or use "
- "a viewonly=True relation)." % (dest_column, dest_mapper))
+ "a viewonly=True relation)." % (dest_column, dest_mapper)
+ )
else:
raise exc.UnmappedColumnError(
"Can't execute sync rule for "
"source column '%s'; mapper '%s' does not map this "
"column. Try using an explicit `foreign_keys` "
"collection which does not include destination column "
- "'%s' (or use a viewonly=True relation)." %
- (source_column, source_mapper, dest_column))
+ "'%s' (or use a viewonly=True relation)."
+ % (source_column, source_mapper, dest_column)
+ )
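
For context: populate() above copies values across the (source column, destination column) pairs of a relationship's synchronize_pairs, e.g. a parent primary key onto a child foreign key during flush. A dictionary-based toy version of that copy step; the real code works on InstanceState objects and mapped Columns:

def populate_sketch(source_dict, dest_dict, synchronize_pairs):
    # one assignment per (l, r) pair, like the loop in populate() above
    for l_key, r_key in synchronize_pairs:
        dest_dict[r_key] = source_dict[l_key]


parent = {"id": 7}
child = {"parent_id": None}
populate_sketch(parent, child, [("id", "parent_id")])
assert child["parent_id"] == 7
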
diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py
index a83a99d78..545811bb4 100644
--- a/lib/sqlalchemy/orm/unitofwork.py
+++ b/lib/sqlalchemy/orm/unitofwork.py
@@ -41,9 +41,11 @@ def track_cascade_events(descriptor, prop):
prop = state.manager.mapper._props[key]
item_state = attributes.instance_state(item)
- if prop._cascade.save_update and \
- (prop.cascade_backrefs or key == initiator.key) and \
- not sess._contains_state(item_state):
+ if (
+ prop._cascade.save_update
+ and (prop.cascade_backrefs or key == initiator.key)
+ and not sess._contains_state(item_state)
+ ):
sess._save_or_update_state(item_state)
return item
@@ -59,12 +61,15 @@ def track_cascade_events(descriptor, prop):
sess._flush_warning(
"collection remove"
if prop.uselist
- else "related attribute delete")
+ else "related attribute delete"
+ )
- if item is not None and \
- item is not attributes.NEVER_SET and \
- item is not attributes.PASSIVE_NO_RESULT and \
- prop._cascade.delete_orphan:
+ if (
+ item is not None
+ and item is not attributes.NEVER_SET
+ and item is not attributes.PASSIVE_NO_RESULT
+ and prop._cascade.delete_orphan
+ ):
# expunge pending orphans
item_state = attributes.instance_state(item)
@@ -93,26 +98,31 @@ def track_cascade_events(descriptor, prop):
prop = state.manager.mapper._props[key]
if newvalue is not None:
newvalue_state = attributes.instance_state(newvalue)
- if prop._cascade.save_update and \
- (prop.cascade_backrefs or key == initiator.key) and \
- not sess._contains_state(newvalue_state):
+ if (
+ prop._cascade.save_update
+ and (prop.cascade_backrefs or key == initiator.key)
+ and not sess._contains_state(newvalue_state)
+ ):
sess._save_or_update_state(newvalue_state)
- if oldvalue is not None and \
- oldvalue is not attributes.NEVER_SET and \
- oldvalue is not attributes.PASSIVE_NO_RESULT and \
- prop._cascade.delete_orphan:
+ if (
+ oldvalue is not None
+ and oldvalue is not attributes.NEVER_SET
+ and oldvalue is not attributes.PASSIVE_NO_RESULT
+ and prop._cascade.delete_orphan
+ ):
# possible to reach here with attributes.NEVER_SET ?
oldvalue_state = attributes.instance_state(oldvalue)
- if oldvalue_state in sess._new and \
- prop.mapper._is_orphan(oldvalue_state):
+ if oldvalue_state in sess._new and prop.mapper._is_orphan(
+ oldvalue_state
+ ):
sess.expunge(oldvalue)
return newvalue
- event.listen(descriptor, 'append', append, raw=True, retval=True)
- event.listen(descriptor, 'remove', remove, raw=True, retval=True)
- event.listen(descriptor, 'set', set_, raw=True, retval=True)
+ event.listen(descriptor, "append", append, raw=True, retval=True)
+ event.listen(descriptor, "remove", remove, raw=True, retval=True)
+ event.listen(descriptor, "set", set_, raw=True, retval=True)
class UOWTransaction(object):
@@ -197,8 +207,9 @@ class UOWTransaction(object):
self.states[state] = (isdelete, True)
- def get_attribute_history(self, state, key,
- passive=attributes.PASSIVE_NO_INITIALIZE):
+ def get_attribute_history(
+ self, state, key, passive=attributes.PASSIVE_NO_INITIALIZE
+ ):
"""facade to attributes.get_state_history(), including
caching of results."""
@@ -213,12 +224,16 @@ class UOWTransaction(object):
# if the cached lookup was "passive" and now
# we want non-passive, do a non-passive lookup and re-cache
- if not cached_passive & attributes.SQL_OK \
- and passive & attributes.SQL_OK:
+ if (
+ not cached_passive & attributes.SQL_OK
+ and passive & attributes.SQL_OK
+ ):
impl = state.manager[key].impl
- history = impl.get_history(state, state.dict,
- attributes.PASSIVE_OFF |
- attributes.LOAD_AGAINST_COMMITTED)
+ history = impl.get_history(
+ state,
+ state.dict,
+ attributes.PASSIVE_OFF | attributes.LOAD_AGAINST_COMMITTED,
+ )
if history and impl.uses_objects:
state_history = history.as_state()
else:
@@ -228,14 +243,14 @@ class UOWTransaction(object):
impl = state.manager[key].impl
# TODO: store the history as (state, object) tuples
# so we don't have to keep converting here
- history = impl.get_history(state, state.dict, passive |
- attributes.LOAD_AGAINST_COMMITTED)
+ history = impl.get_history(
+ state, state.dict, passive | attributes.LOAD_AGAINST_COMMITTED
+ )
if history and impl.uses_objects:
state_history = history.as_state()
else:
state_history = history
- self.attributes[hashkey] = (history, state_history,
- passive)
+ self.attributes[hashkey] = (history, state_history, passive)
return state_history
@@ -247,17 +262,25 @@ class UOWTransaction(object):
if key not in self.presort_actions:
self.presort_actions[key] = Preprocess(processor, fromparent)
- def register_object(self, state, isdelete=False,
- listonly=False, cancel_delete=False,
- operation=None, prop=None):
+ def register_object(
+ self,
+ state,
+ isdelete=False,
+ listonly=False,
+ cancel_delete=False,
+ operation=None,
+ prop=None,
+ ):
if not self.session._contains_state(state):
# this condition is normal when objects are registered
# as part of a relationship cascade operation. it should
# not occur for the top-level register from Session.flush().
if not state.deleted and operation is not None:
- util.warn("Object of type %s not in session, %s operation "
- "along '%s' will not proceed" %
- (orm_util.state_class_str(state), operation, prop))
+ util.warn(
+ "Object of type %s not in session, %s operation "
+ "along '%s' will not proceed"
+ % (orm_util.state_class_str(state), operation, prop)
+ )
return False
if state not in self.states:
@@ -340,24 +363,26 @@ class UOWTransaction(object):
# see if the graph of mapper dependencies has cycles.
self.cycles = cycles = topological.find_cycles(
- self.dependencies,
- list(self.postsort_actions.values()))
+ self.dependencies, list(self.postsort_actions.values())
+ )
if cycles:
# if yes, break the per-mapper actions into
# per-state actions
convert = dict(
- (rec, set(rec.per_state_flush_actions(self)))
- for rec in cycles
+ (rec, set(rec.per_state_flush_actions(self))) for rec in cycles
)
# rewrite the existing dependencies to point to
# the per-state actions for those per-mapper actions
# that were broken up.
for edge in list(self.dependencies):
- if None in edge or \
- edge[0].disabled or edge[1].disabled or \
- cycles.issuperset(edge):
+ if (
+ None in edge
+ or edge[0].disabled
+ or edge[1].disabled
+ or cycles.issuperset(edge)
+ ):
self.dependencies.remove(edge)
elif edge[0] in cycles:
self.dependencies.remove(edge)
@@ -368,10 +393,9 @@ class UOWTransaction(object):
for dep in convert[edge[1]]:
self.dependencies.add((edge[0], dep))
- return set([a for a in self.postsort_actions.values()
- if not a.disabled
- ]
- ).difference(cycles)
+ return set(
+ [a for a in self.postsort_actions.values() if not a.disabled]
+ ).difference(cycles)
def execute(self):
postsort_actions = self._generate_actions()
@@ -386,15 +410,13 @@ class UOWTransaction(object):
# execute
if self.cycles:
for set_ in topological.sort_as_subsets(
- self.dependencies,
- postsort_actions):
+ self.dependencies, postsort_actions
+ ):
while set_:
n = set_.pop()
n.execute_aggregate(self, set_)
else:
- for rec in topological.sort(
- self.dependencies,
- postsort_actions):
+ for rec in topological.sort(self.dependencies, postsort_actions):
rec.execute(self)
def finalize_flush_changes(self):
@@ -410,8 +432,7 @@ class UOWTransaction(object):
states = set(self.states)
isdel = set(
- s for (s, (isdelete, listonly)) in self.states.items()
- if isdelete
+ s for (s, (isdelete, listonly)) in self.states.items() if isdelete
)
other = states.difference(isdel)
if isdel:
@@ -424,8 +445,8 @@ class IterateMappersMixin(object):
def _mappers(self, uow):
if self.fromparent:
return iter(
- m for m in
- self.dependency_processor.parent.self_and_descendants
+ m
+ for m in self.dependency_processor.parent.self_and_descendants
if uow._mapper_for_dep[(m, self.dependency_processor)]
)
else:
@@ -434,8 +455,10 @@ class IterateMappersMixin(object):
class Preprocess(IterateMappersMixin):
__slots__ = (
- 'dependency_processor', 'fromparent', 'processed',
- 'setup_flush_actions'
+ "dependency_processor",
+ "fromparent",
+ "processed",
+ "setup_flush_actions",
)
def __init__(self, dependency_processor, fromparent):
@@ -464,12 +487,14 @@ class Preprocess(IterateMappersMixin):
self.dependency_processor.presort_saves(uow, save_states)
self.processed.update(save_states)
- if (delete_states or save_states):
+ if delete_states or save_states:
if not self.setup_flush_actions and (
- self.dependency_processor.
- prop_has_changes(uow, delete_states, True) or
- self.dependency_processor.
- prop_has_changes(uow, save_states, False)
+ self.dependency_processor.prop_has_changes(
+ uow, delete_states, True
+ )
+ or self.dependency_processor.prop_has_changes(
+ uow, save_states, False
+ )
):
self.dependency_processor.per_property_flush_actions(uow)
self.setup_flush_actions = True
@@ -479,16 +504,14 @@ class Preprocess(IterateMappersMixin):
class PostSortRec(object):
- __slots__ = 'disabled',
+ __slots__ = ("disabled",)
def __new__(cls, uow, *args):
- key = (cls, ) + args
+ key = (cls,) + args
if key in uow.postsort_actions:
return uow.postsort_actions[key]
else:
- uow.postsort_actions[key] = \
- ret = \
- object.__new__(cls)
+ uow.postsort_actions[key] = ret = object.__new__(cls)
ret.disabled = False
return ret
@@ -497,14 +520,15 @@ class PostSortRec(object):
class ProcessAll(IterateMappersMixin, PostSortRec):
- __slots__ = 'dependency_processor', 'isdelete', 'fromparent'
+ __slots__ = "dependency_processor", "isdelete", "fromparent"
def __init__(self, uow, dependency_processor, isdelete, fromparent):
self.dependency_processor = dependency_processor
self.isdelete = isdelete
self.fromparent = fromparent
- uow.deps[dependency_processor.parent.base_mapper].\
- add(dependency_processor)
+ uow.deps[dependency_processor.parent.base_mapper].add(
+ dependency_processor
+ )
def execute(self, uow):
states = self._elements(uow)
@@ -524,7 +548,7 @@ class ProcessAll(IterateMappersMixin, PostSortRec):
return "%s(%s, isdelete=%s)" % (
self.__class__.__name__,
self.dependency_processor,
- self.isdelete
+ self.isdelete,
)
def _elements(self, uow):
@@ -536,7 +560,7 @@ class ProcessAll(IterateMappersMixin, PostSortRec):
class PostUpdateAll(PostSortRec):
- __slots__ = 'mapper', 'isdelete'
+ __slots__ = "mapper", "isdelete"
def __init__(self, uow, mapper, isdelete):
self.mapper = mapper
@@ -550,22 +574,23 @@ class PostUpdateAll(PostSortRec):
class SaveUpdateAll(PostSortRec):
- __slots__ = 'mapper',
+ __slots__ = ("mapper",)
def __init__(self, uow, mapper):
self.mapper = mapper
assert mapper is mapper.base_mapper
def execute(self, uow):
- persistence.save_obj(self.mapper,
- uow.states_for_mapper_hierarchy(
- self.mapper, False, False),
- uow
- )
+ persistence.save_obj(
+ self.mapper,
+ uow.states_for_mapper_hierarchy(self.mapper, False, False),
+ uow,
+ )
def per_state_flush_actions(self, uow):
- states = list(uow.states_for_mapper_hierarchy(
- self.mapper, False, False))
+ states = list(
+ uow.states_for_mapper_hierarchy(self.mapper, False, False)
+ )
base_mapper = self.mapper.base_mapper
delete_all = DeleteAll(uow, base_mapper)
for state in states:
@@ -580,29 +605,27 @@ class SaveUpdateAll(PostSortRec):
dep.per_state_flush_actions(uow, states_for_prop, False)
def __repr__(self):
- return "%s(%s)" % (
- self.__class__.__name__,
- self.mapper
- )
+ return "%s(%s)" % (self.__class__.__name__, self.mapper)
class DeleteAll(PostSortRec):
- __slots__ = 'mapper',
+ __slots__ = ("mapper",)
def __init__(self, uow, mapper):
self.mapper = mapper
assert mapper is mapper.base_mapper
def execute(self, uow):
- persistence.delete_obj(self.mapper,
- uow.states_for_mapper_hierarchy(
- self.mapper, True, False),
- uow
- )
+ persistence.delete_obj(
+ self.mapper,
+ uow.states_for_mapper_hierarchy(self.mapper, True, False),
+ uow,
+ )
def per_state_flush_actions(self, uow):
- states = list(uow.states_for_mapper_hierarchy(
- self.mapper, True, False))
+ states = list(
+ uow.states_for_mapper_hierarchy(self.mapper, True, False)
+ )
base_mapper = self.mapper.base_mapper
save_all = SaveUpdateAll(uow, base_mapper)
for state in states:
@@ -617,14 +640,11 @@ class DeleteAll(PostSortRec):
dep.per_state_flush_actions(uow, states_for_prop, True)
def __repr__(self):
- return "%s(%s)" % (
- self.__class__.__name__,
- self.mapper
- )
+ return "%s(%s)" % (self.__class__.__name__, self.mapper)
class ProcessState(PostSortRec):
- __slots__ = 'dependency_processor', 'isdelete', 'state'
+ __slots__ = "dependency_processor", "isdelete", "state"
def __init__(self, uow, dependency_processor, isdelete, state):
self.dependency_processor = dependency_processor
@@ -635,10 +655,13 @@ class ProcessState(PostSortRec):
cls_ = self.__class__
dependency_processor = self.dependency_processor
isdelete = self.isdelete
- our_recs = [r for r in recs
- if r.__class__ is cls_ and
- r.dependency_processor is dependency_processor and
- r.isdelete is isdelete]
+ our_recs = [
+ r
+ for r in recs
+ if r.__class__ is cls_
+ and r.dependency_processor is dependency_processor
+ and r.isdelete is isdelete
+ ]
recs.difference_update(our_recs)
states = [self.state] + [r.state for r in our_recs]
if isdelete:
@@ -651,12 +674,12 @@ class ProcessState(PostSortRec):
self.__class__.__name__,
self.dependency_processor,
orm_util.state_str(self.state),
- self.isdelete
+ self.isdelete,
)
class SaveUpdateState(PostSortRec):
- __slots__ = 'state', 'mapper'
+ __slots__ = "state", "mapper"
def __init__(self, uow, state):
self.state = state
@@ -665,24 +688,23 @@ class SaveUpdateState(PostSortRec):
def execute_aggregate(self, uow, recs):
cls_ = self.__class__
mapper = self.mapper
- our_recs = [r for r in recs
- if r.__class__ is cls_ and
- r.mapper is mapper]
+ our_recs = [
+ r for r in recs if r.__class__ is cls_ and r.mapper is mapper
+ ]
recs.difference_update(our_recs)
- persistence.save_obj(mapper,
- [self.state] +
- [r.state for r in our_recs],
- uow)
+ persistence.save_obj(
+ mapper, [self.state] + [r.state for r in our_recs], uow
+ )
def __repr__(self):
return "%s(%s)" % (
self.__class__.__name__,
- orm_util.state_str(self.state)
+ orm_util.state_str(self.state),
)
class DeleteState(PostSortRec):
- __slots__ = 'state', 'mapper'
+ __slots__ = "state", "mapper"
def __init__(self, uow, state):
self.state = state
@@ -691,17 +713,17 @@ class DeleteState(PostSortRec):
def execute_aggregate(self, uow, recs):
cls_ = self.__class__
mapper = self.mapper
- our_recs = [r for r in recs
- if r.__class__ is cls_ and
- r.mapper is mapper]
+ our_recs = [
+ r for r in recs if r.__class__ is cls_ and r.mapper is mapper
+ ]
recs.difference_update(our_recs)
states = [self.state] + [r.state for r in our_recs]
- persistence.delete_obj(mapper,
- [s for s in states if uow.states[s][0]],
- uow)
+ persistence.delete_obj(
+ mapper, [s for s in states if uow.states[s][0]], uow
+ )
def __repr__(self):
return "%s(%s)" % (
self.__class__.__name__,
- orm_util.state_str(self.state)
+ orm_util.state_str(self.state),
)
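
By way of example: PostSortRec.__new__ above interns one action record per (class, *args) key within a unit of work, so equivalent flush actions are the same object and dedupe naturally in the dependency graph. A stripped-down sketch of that flyweight pattern:

class FakeUOW(object):
    def __init__(self):
        self.postsort_actions = {}


class Rec(object):
    def __new__(cls, uow, *args):
        key = (cls,) + args
        if key in uow.postsort_actions:
            return uow.postsort_actions[key]
        uow.postsort_actions[key] = ret = object.__new__(cls)
        ret.disabled = False
        return ret


uow = FakeUOW()
assert Rec(uow, "mapper_a") is Rec(uow, "mapper_a")  # interned
assert Rec(uow, "mapper_a") is not Rec(uow, "mapper_b")
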
diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py
index 43709a58c..a1b0cd5da 100644
--- a/lib/sqlalchemy/orm/util.py
+++ b/lib/sqlalchemy/orm/util.py
@@ -12,27 +12,51 @@ from .interfaces import PropComparator, MapperProperty
from . import attributes
import re
-from .base import instance_str, state_str, state_class_str, attribute_str, \
- state_attribute_str, object_mapper, object_state, _none_set, _never_set
+from .base import (
+ instance_str,
+ state_str,
+ state_class_str,
+ attribute_str,
+ state_attribute_str,
+ object_mapper,
+ object_state,
+ _none_set,
+ _never_set,
+)
from .base import class_mapper, _class_to_mapper
from .base import InspectionAttr
from .path_registry import PathRegistry
-all_cascades = frozenset(("delete", "delete-orphan", "all", "merge",
- "expunge", "save-update", "refresh-expire",
- "none"))
+all_cascades = frozenset(
+ (
+ "delete",
+ "delete-orphan",
+ "all",
+ "merge",
+ "expunge",
+ "save-update",
+ "refresh-expire",
+ "none",
+ )
+)
class CascadeOptions(frozenset):
"""Keeps track of the options sent to relationship().cascade"""
- _add_w_all_cascades = all_cascades.difference([
- 'all', 'none', 'delete-orphan'])
+ _add_w_all_cascades = all_cascades.difference(
+ ["all", "none", "delete-orphan"]
+ )
_allowed_cascades = all_cascades
__slots__ = (
- 'save_update', 'delete', 'refresh_expire', 'merge',
- 'expunge', 'delete_orphan')
+ "save_update",
+ "delete",
+ "refresh_expire",
+ "merge",
+ "expunge",
+ "delete_orphan",
+ )
def __new__(cls, value_list):
if isinstance(value_list, util.string_types) or value_list is None:
@@ -40,60 +64,62 @@ class CascadeOptions(frozenset):
values = set(value_list)
if values.difference(cls._allowed_cascades):
raise sa_exc.ArgumentError(
- "Invalid cascade option(s): %s" %
- ", ".join([repr(x) for x in
- sorted(values.difference(cls._allowed_cascades))]))
+ "Invalid cascade option(s): %s"
+ % ", ".join(
+ [
+ repr(x)
+ for x in sorted(
+ values.difference(cls._allowed_cascades)
+ )
+ ]
+ )
+ )
if "all" in values:
values.update(cls._add_w_all_cascades)
if "none" in values:
values.clear()
- values.discard('all')
+ values.discard("all")
self = frozenset.__new__(CascadeOptions, values)
- self.save_update = 'save-update' in values
- self.delete = 'delete' in values
- self.refresh_expire = 'refresh-expire' in values
- self.merge = 'merge' in values
- self.expunge = 'expunge' in values
+ self.save_update = "save-update" in values
+ self.delete = "delete" in values
+ self.refresh_expire = "refresh-expire" in values
+ self.merge = "merge" in values
+ self.expunge = "expunge" in values
self.delete_orphan = "delete-orphan" in values
if self.delete_orphan and not self.delete:
- util.warn("The 'delete-orphan' cascade "
- "option requires 'delete'.")
+ util.warn(
+ "The 'delete-orphan' cascade " "option requires 'delete'."
+ )
return self
def __repr__(self):
- return "CascadeOptions(%r)" % (
- ",".join([x for x in sorted(self)])
- )
+ return "CascadeOptions(%r)" % (",".join([x for x in sorted(self)]))
@classmethod
def from_string(cls, arg):
- values = [
- c for c
- in re.split(r'\s*,\s*', arg or "")
- if c
- ]
+ values = [c for c in re.split(r"\s*,\s*", arg or "") if c]
return cls(values)
-def _validator_events(
- desc, key, validator, include_removes, include_backrefs):
+def _validator_events(desc, key, validator, include_removes, include_backrefs):
"""Runs a validation method on an attribute value to be set or
appended.
"""
if not include_backrefs:
+
def detect_is_backref(state, initiator):
impl = state.manager[key].impl
return initiator.impl is not impl
if include_removes:
+
def append(state, value, initiator):
- if (
- initiator.op is not attributes.OP_BULK_REPLACE and
- (include_backrefs or not detect_is_backref(state, initiator))
+ if initiator.op is not attributes.OP_BULK_REPLACE and (
+ include_backrefs or not detect_is_backref(state, initiator)
):
return validator(state.obj(), key, value, False)
else:
@@ -103,7 +129,8 @@ def _validator_events(
if include_backrefs or not detect_is_backref(state, initiator):
obj = state.obj()
values[:] = [
- validator(obj, key, value, False) for value in values]
+ validator(obj, key, value, False) for value in values
+ ]
def set_(state, value, oldvalue, initiator):
if include_backrefs or not detect_is_backref(state, initiator):
@@ -116,10 +143,10 @@ def _validator_events(
validator(state.obj(), key, value, True)
else:
+
def append(state, value, initiator):
- if (
- initiator.op is not attributes.OP_BULK_REPLACE and
- (include_backrefs or not detect_is_backref(state, initiator))
+ if initiator.op is not attributes.OP_BULK_REPLACE and (
+ include_backrefs or not detect_is_backref(state, initiator)
):
return validator(state.obj(), key, value)
else:
@@ -128,8 +155,7 @@ def _validator_events(
def bulk_set(state, values, initiator):
if include_backrefs or not detect_is_backref(state, initiator):
obj = state.obj()
- values[:] = [
- validator(obj, key, value) for value in values]
+ values[:] = [validator(obj, key, value) for value in values]
def set_(state, value, oldvalue, initiator):
if include_backrefs or not detect_is_backref(state, initiator):
@@ -137,15 +163,16 @@ def _validator_events(
else:
return value
- event.listen(desc, 'append', append, raw=True, retval=True)
- event.listen(desc, 'bulk_replace', bulk_set, raw=True)
- event.listen(desc, 'set', set_, raw=True, retval=True)
+ event.listen(desc, "append", append, raw=True, retval=True)
+ event.listen(desc, "bulk_replace", bulk_set, raw=True)
+ event.listen(desc, "set", set_, raw=True, retval=True)
if include_removes:
event.listen(desc, "remove", remove, raw=True, retval=True)
-def polymorphic_union(table_map, typecolname,
- aliasname='p_union', cast_nulls=True):
+def polymorphic_union(
+ table_map, typecolname, aliasname="p_union", cast_nulls=True
+):
"""Create a ``UNION`` statement used by a polymorphic mapper.
See :ref:`concrete_inheritance` for an example of how
@@ -197,14 +224,22 @@ def polymorphic_union(table_map, typecolname,
for type, table in table_map.items():
if typecolname is not None:
result.append(
- sql.select([col(name, table) for name in colnames] +
- [sql.literal_column(
- sql_util._quote_ddl_expr(type)).
- label(typecolname)],
- from_obj=[table]))
+ sql.select(
+ [col(name, table) for name in colnames]
+ + [
+ sql.literal_column(
+ sql_util._quote_ddl_expr(type)
+ ).label(typecolname)
+ ],
+ from_obj=[table],
+ )
+ )
else:
- result.append(sql.select([col(name, table) for name in colnames],
- from_obj=[table]))
+ result.append(
+ sql.select(
+ [col(name, table) for name in colnames], from_obj=[table]
+ )
+ )
return sql.union_all(*result).alias(aliasname)
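
As an illustration of what the reformatted polymorphic_union() emits, two hypothetical concrete-inheritance tables: one SELECT per table with a literal discriminator column, combined via UNION ALL and aliased:

from sqlalchemy import Column, Integer, MetaData, String, Table
from sqlalchemy.orm import polymorphic_union

metadata = MetaData()
engineers = Table(
    "engineers",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("name", String(50)),
)
managers = Table(
    "managers",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("name", String(50)),
)

pjoin = polymorphic_union(
    {"engineer": engineers, "manager": managers}, "type", "pjoin"
)
print(pjoin.select())  # SELECT ... UNION ALL ... rendered as "pjoin"
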
@@ -284,25 +319,29 @@ first()
class_, ident = args
else:
raise sa_exc.ArgumentError(
- "expected up to three positional arguments, "
- "got %s" % largs)
+ "expected up to three positional arguments, " "got %s" % largs
+ )
identity_token = kwargs.pop("identity_token", None)
if kwargs:
- raise sa_exc.ArgumentError("unknown keyword arguments: %s"
- % ", ".join(kwargs))
+ raise sa_exc.ArgumentError(
+ "unknown keyword arguments: %s" % ", ".join(kwargs)
+ )
mapper = class_mapper(class_)
if row is None:
return mapper.identity_key_from_primary_key(
- util.to_list(ident), identity_token=identity_token)
+ util.to_list(ident), identity_token=identity_token
+ )
else:
return mapper.identity_key_from_row(
- row, identity_token=identity_token)
+ row, identity_token=identity_token
+ )
else:
instance = kwargs.pop("instance")
if kwargs:
- raise sa_exc.ArgumentError("unknown keyword arguments: %s"
- % ", ".join(kwargs.keys))
+ raise sa_exc.ArgumentError(
+ "unknown keyword arguments: %s" % ", ".join(kwargs.keys)
+ )
mapper = object_mapper(instance)
return mapper.identity_key_from_instance(instance)
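
For context: the positional and instance-based calling forms of identity_key() handled above, with a hypothetical mapped class; neither form needs a database connection:

from sqlalchemy import Column, Integer, String
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm.util import identity_key

Base = declarative_base()


class User(Base):
    __tablename__ = "user"
    id = Column(Integer, primary_key=True)
    name = Column(String(50))


print(identity_key(User, 5))  # class + primary key form
u = User(id=5)
print(identity_key(instance=u))  # derived from the object's identity
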
@@ -313,9 +352,15 @@ class ORMAdapter(sql_util.ColumnAdapter):
"""
- def __init__(self, entity, equivalents=None, adapt_required=False,
- chain_to=None, allow_label_resolve=True,
- anonymize_labels=False):
+ def __init__(
+ self,
+ entity,
+ equivalents=None,
+ adapt_required=False,
+ chain_to=None,
+ allow_label_resolve=True,
+ anonymize_labels=False,
+ ):
info = inspection.inspect(entity)
self.mapper = info.mapper
@@ -327,15 +372,18 @@ class ORMAdapter(sql_util.ColumnAdapter):
self.aliased_class = None
sql_util.ColumnAdapter.__init__(
- self, selectable, equivalents, chain_to,
+ self,
+ selectable,
+ equivalents,
+ chain_to,
adapt_required=adapt_required,
allow_label_resolve=allow_label_resolve,
anonymize_labels=anonymize_labels,
- include_fn=self._include_fn
+ include_fn=self._include_fn,
)
def _include_fn(self, elem):
- entity = elem._annotations.get('parentmapper', None)
+ entity = elem._annotations.get("parentmapper", None)
return not entity or entity.isa(self.mapper)
@@ -380,20 +428,25 @@ class AliasedClass(object):
"""
- def __init__(self, cls, alias=None,
- name=None,
- flat=False,
- adapt_on_names=False,
- # TODO: None for default here?
- with_polymorphic_mappers=(),
- with_polymorphic_discriminator=None,
- base_alias=None,
- use_mapper_path=False,
- represents_outer_join=False):
+ def __init__(
+ self,
+ cls,
+ alias=None,
+ name=None,
+ flat=False,
+ adapt_on_names=False,
+ # TODO: None for default here?
+ with_polymorphic_mappers=(),
+ with_polymorphic_discriminator=None,
+ base_alias=None,
+ use_mapper_path=False,
+ represents_outer_join=False,
+ ):
mapper = _class_to_mapper(cls)
if alias is None:
alias = mapper._with_polymorphic_selectable.alias(
- name=name, flat=flat)
+ name=name, flat=flat
+ )
self._aliased_insp = AliasedInsp(
self,
@@ -409,14 +462,14 @@ class AliasedClass(object):
base_alias,
use_mapper_path,
adapt_on_names,
- represents_outer_join
+ represents_outer_join,
)
- self.__name__ = 'AliasedClass_%s' % mapper.class_.__name__
+ self.__name__ = "AliasedClass_%s" % mapper.class_.__name__
def __getattr__(self, key):
try:
- _aliased_insp = self.__dict__['_aliased_insp']
+ _aliased_insp = self.__dict__["_aliased_insp"]
except KeyError:
raise AttributeError()
else:
@@ -434,13 +487,13 @@ class AliasedClass(object):
ret = attr.adapt_to_entity(_aliased_insp)
setattr(self, key, ret)
return ret
- elif hasattr(attr, 'func_code'):
+ elif hasattr(attr, "func_code"):
is_method = getattr(_aliased_insp._target, key, None)
if is_method and is_method.__self__ is not None:
return util.types.MethodType(attr.__func__, self, self)
else:
return None
- elif hasattr(attr, '__get__'):
+ elif hasattr(attr, "__get__"):
ret = attr.__get__(None, self)
if isinstance(ret, PropComparator):
return ret.adapt_to_entity(_aliased_insp)
@@ -450,8 +503,10 @@ class AliasedClass(object):
return attr
def __repr__(self):
- return '<AliasedClass at 0x%x; %s>' % (
- id(self), self._aliased_insp._target.__name__)
+ return "<AliasedClass at 0x%x; %s>" % (
+ id(self),
+ self._aliased_insp._target.__name__,
+ )
class AliasedInsp(InspectionAttr):
@@ -490,10 +545,19 @@ class AliasedInsp(InspectionAttr):
"""
- def __init__(self, entity, mapper, selectable, name,
- with_polymorphic_mappers, polymorphic_on,
- _base_alias, _use_mapper_path, adapt_on_names,
- represents_outer_join):
+ def __init__(
+ self,
+ entity,
+ mapper,
+ selectable,
+ name,
+ with_polymorphic_mappers,
+ polymorphic_on,
+ _base_alias,
+ _use_mapper_path,
+ adapt_on_names,
+ represents_outer_join,
+ ):
self.entity = entity
self.mapper = mapper
self.selectable = selectable
@@ -505,18 +569,28 @@ class AliasedInsp(InspectionAttr):
self.represents_outer_join = represents_outer_join
self._adapter = sql_util.ColumnAdapter(
- selectable, equivalents=mapper._equivalent_columns,
- adapt_on_names=adapt_on_names, anonymize_labels=True)
+ selectable,
+ equivalents=mapper._equivalent_columns,
+ adapt_on_names=adapt_on_names,
+ anonymize_labels=True,
+ )
self._adapt_on_names = adapt_on_names
self._target = mapper.class_
for poly in self.with_polymorphic_mappers:
if poly is not mapper:
- setattr(self.entity, poly.class_.__name__,
- AliasedClass(poly.class_, selectable, base_alias=self,
- adapt_on_names=adapt_on_names,
- use_mapper_path=_use_mapper_path))
+ setattr(
+ self.entity,
+ poly.class_.__name__,
+ AliasedClass(
+ poly.class_,
+ selectable,
+ base_alias=self,
+ adapt_on_names=adapt_on_names,
+ use_mapper_path=_use_mapper_path,
+ ),
+ )
is_aliased_class = True
"always returns True"
@@ -536,39 +610,35 @@ class AliasedInsp(InspectionAttr):
def __getstate__(self):
return {
- 'entity': self.entity,
- 'mapper': self.mapper,
- 'alias': self.selectable,
- 'name': self.name,
- 'adapt_on_names': self._adapt_on_names,
- 'with_polymorphic_mappers':
- self.with_polymorphic_mappers,
- 'with_polymorphic_discriminator':
- self.polymorphic_on,
- 'base_alias': self._base_alias,
- 'use_mapper_path': self._use_mapper_path,
- 'represents_outer_join': self.represents_outer_join
+ "entity": self.entity,
+ "mapper": self.mapper,
+ "alias": self.selectable,
+ "name": self.name,
+ "adapt_on_names": self._adapt_on_names,
+ "with_polymorphic_mappers": self.with_polymorphic_mappers,
+ "with_polymorphic_discriminator": self.polymorphic_on,
+ "base_alias": self._base_alias,
+ "use_mapper_path": self._use_mapper_path,
+ "represents_outer_join": self.represents_outer_join,
}
def __setstate__(self, state):
self.__init__(
- state['entity'],
- state['mapper'],
- state['alias'],
- state['name'],
- state['with_polymorphic_mappers'],
- state['with_polymorphic_discriminator'],
- state['base_alias'],
- state['use_mapper_path'],
- state['adapt_on_names'],
- state['represents_outer_join']
+ state["entity"],
+ state["mapper"],
+ state["alias"],
+ state["name"],
+ state["with_polymorphic_mappers"],
+ state["with_polymorphic_discriminator"],
+ state["base_alias"],
+ state["use_mapper_path"],
+ state["adapt_on_names"],
+ state["represents_outer_join"],
)
def _adapt_element(self, elem):
- return self._adapter.traverse(elem).\
- _annotate({
- 'parententity': self,
- 'parentmapper': self.mapper}
+ return self._adapter.traverse(elem)._annotate(
+ {"parententity": self, "parentmapper": self.mapper}
)
def _entity_for_mapper(self, mapper):
@@ -578,12 +648,12 @@ class AliasedInsp(InspectionAttr):
return self
else:
return getattr(
- self.entity, mapper.class_.__name__)._aliased_insp
+ self.entity, mapper.class_.__name__
+ )._aliased_insp
elif mapper.isa(self.mapper):
return self
else:
- assert False, "mapper %s doesn't correspond to %s" % (
- mapper, self)
+ assert False, "mapper %s doesn't correspond to %s" % (mapper, self)
@util.memoized_property
def _memoized_values(self):
@@ -599,11 +669,15 @@ class AliasedInsp(InspectionAttr):
def __repr__(self):
if self.with_polymorphic_mappers:
with_poly = "(%s)" % ", ".join(
- mp.class_.__name__ for mp in self.with_polymorphic_mappers)
+ mp.class_.__name__ for mp in self.with_polymorphic_mappers
+ )
else:
with_poly = ""
- return '<AliasedInsp at 0x%x; %s%s>' % (
- id(self), self.class_.__name__, with_poly)
+ return "<AliasedInsp at 0x%x; %s%s>" % (
+ id(self),
+ self.class_.__name__,
+ with_poly,
+ )
inspection._inspects(AliasedClass)(lambda target: target._aliased_insp)
@@ -700,15 +774,26 @@ def aliased(element, alias=None, name=None, flat=False, adapt_on_names=False):
)
return element.alias(name, flat=flat)
else:
- return AliasedClass(element, alias=alias, flat=flat,
- name=name, adapt_on_names=adapt_on_names)
+ return AliasedClass(
+ element,
+ alias=alias,
+ flat=flat,
+ name=name,
+ adapt_on_names=adapt_on_names,
+ )
-def with_polymorphic(base, classes, selectable=False,
- flat=False,
- polymorphic_on=None, aliased=False,
- innerjoin=False, _use_mapper_path=False,
- _existing_alias=None):
+def with_polymorphic(
+ base,
+ classes,
+ selectable=False,
+ flat=False,
+ polymorphic_on=None,
+ aliased=False,
+ innerjoin=False,
+ _use_mapper_path=False,
+ _existing_alias=None,
+):
"""Produce an :class:`.AliasedClass` construct which specifies
columns for descendant mappers of the given base.
@@ -777,24 +862,26 @@ def with_polymorphic(base, classes, selectable=False,
if _existing_alias:
assert _existing_alias.mapper is primary_mapper
classes = util.to_set(classes)
- new_classes = set([
- mp.class_ for mp in
- _existing_alias.with_polymorphic_mappers])
+ new_classes = set(
+ [mp.class_ for mp in _existing_alias.with_polymorphic_mappers]
+ )
if classes == new_classes:
return _existing_alias
else:
classes = classes.union(new_classes)
- mappers, selectable = primary_mapper.\
- _with_polymorphic_args(classes, selectable,
- innerjoin=innerjoin)
+ mappers, selectable = primary_mapper._with_polymorphic_args(
+ classes, selectable, innerjoin=innerjoin
+ )
if aliased or flat:
selectable = selectable.alias(flat=flat)
- return AliasedClass(base,
- selectable,
- with_polymorphic_mappers=mappers,
- with_polymorphic_discriminator=polymorphic_on,
- use_mapper_path=_use_mapper_path,
- represents_outer_join=not innerjoin)
+ return AliasedClass(
+ base,
+ selectable,
+ with_polymorphic_mappers=mappers,
+ with_polymorphic_discriminator=polymorphic_on,
+ use_mapper_path=_use_mapper_path,
+ represents_outer_join=not innerjoin,
+ )
def _orm_annotate(element, exclude=None):
@@ -804,7 +891,7 @@ def _orm_annotate(element, exclude=None):
Elements within the exclude collection will be cloned but not annotated.
"""
- return sql_util._deep_annotate(element, {'_orm_adapt': True}, exclude)
+ return sql_util._deep_annotate(element, {"_orm_adapt": True}, exclude)
def _orm_deannotate(element):
@@ -816,9 +903,9 @@ def _orm_deannotate(element):
"""
- return sql_util._deep_deannotate(element,
- values=("_orm_adapt", "parententity")
- )
+ return sql_util._deep_deannotate(
+ element, values=("_orm_adapt", "parententity")
+ )
def _orm_full_deannotate(element):
@@ -831,12 +918,18 @@ class _ORMJoin(expression.Join):
__visit_name__ = expression.Join.__visit_name__
def __init__(
- self,
- left, right, onclause=None, isouter=False,
- full=False, _left_memo=None, _right_memo=None):
+ self,
+ left,
+ right,
+ onclause=None,
+ isouter=False,
+ full=False,
+ _left_memo=None,
+ _right_memo=None,
+ ):
left_info = inspection.inspect(left)
- left_orm_info = getattr(left, '_joined_from_info', left_info)
+ left_orm_info = getattr(left, "_joined_from_info", left_info)
right_info = inspection.inspect(right)
adapt_to = right_info.selectable
@@ -859,19 +952,18 @@ class _ORMJoin(expression.Join):
prop = None
if prop:
- if sql_util.clause_is_present(
- on_selectable, left_info.selectable):
+ if sql_util.clause_is_present(on_selectable, left_info.selectable):
adapt_from = on_selectable
else:
adapt_from = left_info.selectable
- pj, sj, source, dest, \
- secondary, target_adapter = prop._create_joins(
- source_selectable=adapt_from,
- dest_selectable=adapt_to,
- source_polymorphic=True,
- dest_polymorphic=True,
- of_type=right_info.mapper)
+ pj, sj, source, dest, secondary, target_adapter = prop._create_joins(
+ source_selectable=adapt_from,
+ dest_selectable=adapt_to,
+ source_polymorphic=True,
+ dest_polymorphic=True,
+ of_type=right_info.mapper,
+ )
if sj is not None:
if isouter:
@@ -887,8 +979,11 @@ class _ORMJoin(expression.Join):
expression.Join.__init__(self, left, right, onclause, isouter, full)
- if not prop and getattr(right_info, 'mapper', None) \
- and right_info.mapper.single:
+ if (
+ not prop
+ and getattr(right_info, "mapper", None)
+ and right_info.mapper.single
+ ):
# if single inheritance target and we are using a manual
# or implicit ON clause, augment it the same way we'd augment the
# WHERE.
@@ -911,33 +1006,39 @@ class _ORMJoin(expression.Join):
assert self.right is leftmost
left = _ORMJoin(
- self.left, other.left,
- self.onclause, isouter=self.isouter,
+ self.left,
+ other.left,
+ self.onclause,
+ isouter=self.isouter,
_left_memo=self._left_memo,
- _right_memo=other._left_memo
+ _right_memo=other._left_memo,
)
return _ORMJoin(
left,
other.right,
- other.onclause, isouter=other.isouter,
- _right_memo=other._right_memo
+ other.onclause,
+ isouter=other.isouter,
+ _right_memo=other._right_memo,
)
def join(
- self, right, onclause=None,
- isouter=False, full=False, join_to_left=None):
+ self,
+ right,
+ onclause=None,
+ isouter=False,
+ full=False,
+ join_to_left=None,
+ ):
        return _ORMJoin(self, right, onclause, isouter, full)
- def outerjoin(
- self, right, onclause=None,
- full=False, join_to_left=None):
+ def outerjoin(self, right, onclause=None, full=False, join_to_left=None):
return _ORMJoin(self, right, onclause, True, full=full)
def join(
- left, right, onclause=None, isouter=False,
- full=False, join_to_left=None):
+ left, right, onclause=None, isouter=False, full=False, join_to_left=None
+):
r"""Produce an inner join between left and right clauses.
:func:`.orm.join` is an extension to the core join interface
@@ -1085,8 +1186,9 @@ def _entity_isa(given, mapper):
"""
if given.is_aliased_class:
- return mapper in given.with_polymorphic_mappers or \
- given.mapper.isa(mapper)
+ return mapper in given.with_polymorphic_mappers or given.mapper.isa(
+ mapper
+ )
elif given.with_polymorphic_mappers:
return mapper in given.with_polymorphic_mappers
else:
@@ -1126,5 +1228,7 @@ def randomize_unitofwork():
from sqlalchemy.orm import unitofwork, session, mapper, dependency
from sqlalchemy.util import topological
from sqlalchemy.testing.util import RandomSet
- topological.set = unitofwork.set = session.set = mapper.set = \
- dependency.set = RandomSet
+
+ topological.set = (
+ unitofwork.set
+ ) = session.set = mapper.set = dependency.set = RandomSet
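As a rough usage sketch of the two public helpers reflowed above (`Employee`, `Engineer`, and `Manager` are hypothetical joined-inheritance mapped classes):

    from sqlalchemy.orm import aliased, with_polymorphic

    emp_alias = aliased(Employee, name="emp")   # -> AliasedClass
    emp_poly = with_polymorphic(Employee, [Engineer, Manager])
    # emp_poly selects Employee rows with the Engineer/Manager columns
    # available up front; aliased=True / flat=True would additionally
    # wrap the polymorphic selectable in an alias.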
diff --git a/lib/sqlalchemy/pool/__init__.py b/lib/sqlalchemy/pool/__init__.py
index f2f035051..2aa6eeeb7 100644
--- a/lib/sqlalchemy/pool/__init__.py
+++ b/lib/sqlalchemy/pool/__init__.py
@@ -20,7 +20,12 @@ SQLAlchemy connection pool.
from .base import _refs # noqa
from .base import Pool # noqa
from .impl import ( # noqa
- QueuePool, StaticPool, NullPool, AssertionPool, SingletonThreadPool)
+ QueuePool,
+ StaticPool,
+ NullPool,
+ AssertionPool,
+ SingletonThreadPool,
+)
from .dbapi_proxy import manage, clear_managers # noqa
from .base import reset_rollback, reset_commit, reset_none # noqa
diff --git a/lib/sqlalchemy/pool/base.py b/lib/sqlalchemy/pool/base.py
index 442d3b64a..382e740c6 100644
--- a/lib/sqlalchemy/pool/base.py
+++ b/lib/sqlalchemy/pool/base.py
@@ -18,9 +18,9 @@ from .. import exc, log, event, interfaces, util
from ..util import threading
-reset_rollback = util.symbol('reset_rollback')
-reset_commit = util.symbol('reset_commit')
-reset_none = util.symbol('reset_none')
+reset_rollback = util.symbol("reset_rollback")
+reset_commit = util.symbol("reset_commit")
+reset_none = util.symbol("reset_none")
class _ConnDialect(object):
@@ -46,7 +46,8 @@ class _ConnDialect(object):
def do_ping(self, dbapi_connection):
raise NotImplementedError(
"The ping feature requires that a dialect is "
- "passed to the connection pool.")
+ "passed to the connection pool."
+ )
class Pool(log.Identified):
@@ -55,16 +56,20 @@ class Pool(log.Identified):
_dialect = _ConnDialect()
- def __init__(self,
- creator, recycle=-1, echo=None,
- use_threadlocal=False,
- logging_name=None,
- reset_on_return=True,
- listeners=None,
- events=None,
- dialect=None,
- pre_ping=False,
- _dispatch=None):
+ def __init__(
+ self,
+ creator,
+ recycle=-1,
+ echo=None,
+ use_threadlocal=False,
+ logging_name=None,
+ reset_on_return=True,
+ listeners=None,
+ events=None,
+ dialect=None,
+ pre_ping=False,
+ _dispatch=None,
+ ):
"""
Construct a Pool.
@@ -200,16 +205,16 @@ class Pool(log.Identified):
self._invalidate_time = 0
self._use_threadlocal = use_threadlocal
self._pre_ping = pre_ping
- if reset_on_return in ('rollback', True, reset_rollback):
+ if reset_on_return in ("rollback", True, reset_rollback):
self._reset_on_return = reset_rollback
- elif reset_on_return in ('none', None, False, reset_none):
+ elif reset_on_return in ("none", None, False, reset_none):
self._reset_on_return = reset_none
- elif reset_on_return in ('commit', reset_commit):
+ elif reset_on_return in ("commit", reset_commit):
self._reset_on_return = reset_commit
else:
raise exc.ArgumentError(
- "Invalid value for 'reset_on_return': %r"
- % reset_on_return)
+ "Invalid value for 'reset_on_return': %r" % reset_on_return
+ )
self.echo = echo
@@ -223,17 +228,18 @@ class Pool(log.Identified):
if listeners:
util.warn_deprecated(
"The 'listeners' argument to Pool (and "
- "create_engine()) is deprecated. Use event.listen().")
+ "create_engine()) is deprecated. Use event.listen()."
+ )
for l in listeners:
self.add_listener(l)
@property
def _creator(self):
- return self.__dict__['_creator']
+ return self.__dict__["_creator"]
@_creator.setter
def _creator(self, creator):
- self.__dict__['_creator'] = creator
+ self.__dict__["_creator"] = creator
self._invoke_creator = self._should_wrap_creator(creator)
def _should_wrap_creator(self, creator):
@@ -252,7 +258,7 @@ class Pool(log.Identified):
# look for the exact arg signature that DefaultStrategy
# sends us
- if (argspec[0], argspec[3]) == (['connection_record'], (None,)):
+ if (argspec[0], argspec[3]) == (["connection_record"], (None,)):
return creator
# or just a single positional
elif positionals == 1:
@@ -268,11 +274,13 @@ class Pool(log.Identified):
try:
self._dialect.do_close(connection)
except Exception:
- self.logger.error("Exception closing connection %r",
- connection, exc_info=True)
+ self.logger.error(
+ "Exception closing connection %r", connection, exc_info=True
+ )
@util.deprecated(
- 2.7, "Pool.add_listener is deprecated. Use event.listen()")
+ 2.7, "Pool.add_listener is deprecated. Use event.listen()"
+ )
def add_listener(self, listener):
"""Add a :class:`.PoolListener`-like object to this pool.
@@ -315,7 +323,7 @@ class Pool(log.Identified):
rec = getattr(connection, "_connection_record", None)
if not rec or self._invalidate_time < rec.starttime:
self._invalidate_time = time.time()
- if _checkin and getattr(connection, 'is_valid', False):
+ if _checkin and getattr(connection, "is_valid", False):
connection.invalidate(exception)
def recreate(self):
@@ -491,15 +499,14 @@ class _ConnectionRecord(object):
fairy = _ConnectionFairy(dbapi_connection, rec, echo)
rec.fairy_ref = weakref.ref(
fairy,
- lambda ref: _finalize_fairy and
- _finalize_fairy(
- None,
- rec, pool, ref, echo)
+ lambda ref: _finalize_fairy
+ and _finalize_fairy(None, rec, pool, ref, echo),
)
_refs.add(rec)
if echo:
- pool.logger.debug("Connection %r checked out from pool",
- dbapi_connection)
+ pool.logger.debug(
+ "Connection %r checked out from pool", dbapi_connection
+ )
return fairy
def _checkin_failed(self, err):
@@ -563,12 +570,16 @@ class _ConnectionRecord(object):
self.__pool.logger.info(
"%sInvalidate connection %r (reason: %s:%s)",
"Soft " if soft else "",
- self.connection, e.__class__.__name__, e)
+ self.connection,
+ e.__class__.__name__,
+ e,
+ )
else:
self.__pool.logger.info(
"%sInvalidate connection %r",
"Soft " if soft else "",
- self.connection)
+ self.connection,
+ )
if soft:
self._soft_invalidate_time = time.time()
else:
@@ -580,24 +591,26 @@ class _ConnectionRecord(object):
if self.connection is None:
self.info.clear()
self.__connect()
- elif self.__pool._recycle > -1 and \
- time.time() - self.starttime > self.__pool._recycle:
+ elif (
+ self.__pool._recycle > -1
+ and time.time() - self.starttime > self.__pool._recycle
+ ):
self.__pool.logger.info(
- "Connection %r exceeded timeout; recycling",
- self.connection)
+ "Connection %r exceeded timeout; recycling", self.connection
+ )
recycle = True
elif self.__pool._invalidate_time > self.starttime:
self.__pool.logger.info(
- "Connection %r invalidated due to pool invalidation; " +
- "recycling",
- self.connection
+ "Connection %r invalidated due to pool invalidation; "
+ + "recycling",
+ self.connection,
)
recycle = True
elif self._soft_invalidate_time > self.starttime:
self.__pool.logger.info(
- "Connection %r invalidated due to local soft invalidation; " +
- "recycling",
- self.connection
+ "Connection %r invalidated due to local soft invalidation; "
+ + "recycling",
+ self.connection,
)
recycle = True
@@ -631,15 +644,16 @@ class _ConnectionRecord(object):
raise
else:
if first_connect_check:
- pool.dispatch.first_connect.\
- for_modify(pool.dispatch).\
- exec_once(self.connection, self)
+ pool.dispatch.first_connect.for_modify(
+ pool.dispatch
+ ).exec_once(self.connection, self)
if pool.dispatch.connect:
pool.dispatch.connect(self.connection, self)
-def _finalize_fairy(connection, connection_record,
- pool, ref, echo, fairy=None):
+def _finalize_fairy(
+ connection, connection_record, pool, ref, echo, fairy=None
+):
"""Cleanup for a :class:`._ConnectionFairy` whether or not it's already
been garbage collected.
@@ -654,12 +668,14 @@ def _finalize_fairy(connection, connection_record,
if connection is not None:
if connection_record and echo:
- pool.logger.debug("Connection %r being returned to pool",
- connection)
+ pool.logger.debug(
+ "Connection %r being returned to pool", connection
+ )
try:
fairy = fairy or _ConnectionFairy(
- connection, connection_record, echo)
+ connection, connection_record, echo
+ )
assert fairy.connection is connection
fairy._reset(pool)
@@ -670,7 +686,8 @@ def _finalize_fairy(connection, connection_record,
pool._close_connection(connection)
except BaseException as e:
pool.logger.error(
- "Exception during reset or similar", exc_info=True)
+ "Exception during reset or similar", exc_info=True
+ )
if connection_record:
connection_record.invalidate(e=e)
if not isinstance(e, Exception):
@@ -752,8 +769,9 @@ class _ConnectionFairy(object):
raise exc.InvalidRequestError("This connection is closed")
fairy._counter += 1
- if (not pool.dispatch.checkout and not pool._pre_ping) or \
- fairy._counter != 1:
+ if (
+ not pool.dispatch.checkout and not pool._pre_ping
+ ) or fairy._counter != 1:
return fairy
# Pool listeners can trigger a reconnection on checkout, as well
@@ -767,38 +785,45 @@ class _ConnectionFairy(object):
if pool._pre_ping:
if fairy._echo:
pool.logger.debug(
- "Pool pre-ping on connection %s",
- fairy.connection)
+ "Pool pre-ping on connection %s", fairy.connection
+ )
result = pool._dialect.do_ping(fairy.connection)
if not result:
if fairy._echo:
pool.logger.debug(
"Pool pre-ping on connection %s failed, "
- "will invalidate pool", fairy.connection)
+ "will invalidate pool",
+ fairy.connection,
+ )
raise exc.InvalidatePoolError()
- pool.dispatch.checkout(fairy.connection,
- fairy._connection_record,
- fairy)
+ pool.dispatch.checkout(
+ fairy.connection, fairy._connection_record, fairy
+ )
return fairy
except exc.DisconnectionError as e:
if e.invalidate_pool:
pool.logger.info(
"Disconnection detected on checkout, "
"invalidating all pooled connections prior to "
- "current timestamp (reason: %r)", e)
+ "current timestamp (reason: %r)",
+ e,
+ )
fairy._connection_record.invalidate(e)
pool._invalidate(fairy, e, _checkin=False)
else:
pool.logger.info(
"Disconnection detected on checkout, "
"invalidating individual connection %s (reason: %r)",
- fairy.connection, e)
+ fairy.connection,
+ e,
+ )
fairy._connection_record.invalidate(e)
try:
- fairy.connection = \
+ fairy.connection = (
fairy._connection_record.get_connection()
+ )
except Exception as err:
with util.safe_reraise():
fairy._connection_record._checkin_failed(err)
@@ -813,8 +838,14 @@ class _ConnectionFairy(object):
return _ConnectionFairy._checkout(self._pool, fairy=self)
def _checkin(self):
- _finalize_fairy(self.connection, self._connection_record,
- self._pool, None, self._echo, fairy=self)
+ _finalize_fairy(
+ self.connection,
+ self._connection_record,
+ self._pool,
+ None,
+ self._echo,
+ fairy=self,
+ )
self.connection = None
self._connection_record = None
@@ -825,20 +856,22 @@ class _ConnectionFairy(object):
pool.dispatch.reset(self, self._connection_record)
if pool._reset_on_return is reset_rollback:
if self._echo:
- pool.logger.debug("Connection %s rollback-on-return%s",
- self.connection,
- ", via agent"
- if self._reset_agent else "")
+ pool.logger.debug(
+ "Connection %s rollback-on-return%s",
+ self.connection,
+ ", via agent" if self._reset_agent else "",
+ )
if self._reset_agent:
self._reset_agent.rollback()
else:
pool._dialect.do_rollback(self)
elif pool._reset_on_return is reset_commit:
if self._echo:
- pool.logger.debug("Connection %s commit-on-return%s",
- self.connection,
- ", via agent"
- if self._reset_agent else "")
+ pool.logger.debug(
+ "Connection %s commit-on-return%s",
+ self.connection,
+ ", via agent" if self._reset_agent else "",
+ )
if self._reset_agent:
self._reset_agent.commit()
else:
@@ -964,5 +997,3 @@ class _ConnectionFairy(object):
self._counter -= 1
if self._counter == 0:
self._checkin()
-
-
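The `reset_on_return` normalization above accepts several spellings per behavior; a hedged sketch (`connect` is a hypothetical zero-argument DBAPI connect callable):

    from sqlalchemy.pool import QueuePool

    QueuePool(connect, reset_on_return="rollback")  # or True (default)
    QueuePool(connect, reset_on_return="commit")    # commit on check-in
    QueuePool(connect, reset_on_return="none")      # or None / False
    # any other value raises
    # ArgumentError("Invalid value for 'reset_on_return': ...")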
diff --git a/lib/sqlalchemy/pool/dbapi_proxy.py b/lib/sqlalchemy/pool/dbapi_proxy.py
index aa439bd23..425c4a114 100644
--- a/lib/sqlalchemy/pool/dbapi_proxy.py
+++ b/lib/sqlalchemy/pool/dbapi_proxy.py
@@ -101,9 +101,10 @@ class _DBProxy(object):
self._create_pool_mutex.acquire()
try:
if key not in self.pools:
- kw.pop('sa_pool_key', None)
+ kw.pop("sa_pool_key", None)
pool = self.poolclass(
- lambda: self.module.connect(*args, **kw), **self.kw)
+ lambda: self.module.connect(*args, **kw), **self.kw
+ )
self.pools[key] = pool
return pool
else:
@@ -138,9 +139,6 @@ class _DBProxy(object):
def _serialize(self, *args, **kw):
if "sa_pool_key" in kw:
- return kw['sa_pool_key']
+ return kw["sa_pool_key"]
- return tuple(
- list(args) +
- [(k, kw[k]) for k in sorted(kw)]
- )
+ return tuple(list(args) + [(k, kw[k]) for k in sorted(kw)])
diff --git a/lib/sqlalchemy/pool/impl.py b/lib/sqlalchemy/pool/impl.py
index 3058d6247..6159f6a5b 100644
--- a/lib/sqlalchemy/pool/impl.py
+++ b/lib/sqlalchemy/pool/impl.py
@@ -30,8 +30,15 @@ class QueuePool(Pool):
"""
- def __init__(self, creator, pool_size=5, max_overflow=10, timeout=30, use_lifo=False,
- **kw):
+ def __init__(
+ self,
+ creator,
+ pool_size=5,
+ max_overflow=10,
+ timeout=30,
+ use_lifo=False,
+ **kw
+ ):
r"""
Construct a QueuePool.
@@ -117,8 +124,10 @@ class QueuePool(Pool):
else:
raise exc.TimeoutError(
"QueuePool limit of size %d overflow %d reached, "
- "connection timed out, timeout %d" %
- (self.size(), self.overflow(), self._timeout), code="3o7r")
+ "connection timed out, timeout %d"
+ % (self.size(), self.overflow(), self._timeout),
+ code="3o7r",
+ )
if self._inc_overflow():
try:
@@ -150,15 +159,19 @@ class QueuePool(Pool):
def recreate(self):
self.logger.info("Pool recreating")
- return self.__class__(self._creator, pool_size=self._pool.maxsize,
- max_overflow=self._max_overflow,
- timeout=self._timeout,
- recycle=self._recycle, echo=self.echo,
- logging_name=self._orig_logging_name,
- use_threadlocal=self._use_threadlocal,
- reset_on_return=self._reset_on_return,
- _dispatch=self.dispatch,
- dialect=self._dialect)
+ return self.__class__(
+ self._creator,
+ pool_size=self._pool.maxsize,
+ max_overflow=self._max_overflow,
+ timeout=self._timeout,
+ recycle=self._recycle,
+ echo=self.echo,
+ logging_name=self._orig_logging_name,
+ use_threadlocal=self._use_threadlocal,
+ reset_on_return=self._reset_on_return,
+ _dispatch=self.dispatch,
+ dialect=self._dialect,
+ )
def dispose(self):
while True:
@@ -172,12 +185,17 @@ class QueuePool(Pool):
self.logger.info("Pool disposed. %s", self.status())
def status(self):
- return "Pool size: %d Connections in pool: %d "\
- "Current Overflow: %d Current Checked out "\
- "connections: %d" % (self.size(),
- self.checkedin(),
- self.overflow(),
- self.checkedout())
+ return (
+ "Pool size: %d Connections in pool: %d "
+ "Current Overflow: %d Current Checked out "
+ "connections: %d"
+ % (
+ self.size(),
+ self.checkedin(),
+ self.overflow(),
+ self.checkedout(),
+ )
+ )
def size(self):
return self._pool.maxsize
@@ -221,14 +239,16 @@ class NullPool(Pool):
def recreate(self):
self.logger.info("Pool recreating")
- return self.__class__(self._creator,
- recycle=self._recycle,
- echo=self.echo,
- logging_name=self._orig_logging_name,
- use_threadlocal=self._use_threadlocal,
- reset_on_return=self._reset_on_return,
- _dispatch=self.dispatch,
- dialect=self._dialect)
+ return self.__class__(
+ self._creator,
+ recycle=self._recycle,
+ echo=self.echo,
+ logging_name=self._orig_logging_name,
+ use_threadlocal=self._use_threadlocal,
+ reset_on_return=self._reset_on_return,
+ _dispatch=self.dispatch,
+ dialect=self._dialect,
+ )
def dispose(self):
pass
@@ -266,7 +286,7 @@ class SingletonThreadPool(Pool):
"""
def __init__(self, creator, pool_size=5, **kw):
- kw['use_threadlocal'] = True
+ kw["use_threadlocal"] = True
Pool.__init__(self, creator, **kw)
self._conn = threading.local()
self._all_conns = set()
@@ -274,15 +294,17 @@ class SingletonThreadPool(Pool):
def recreate(self):
self.logger.info("Pool recreating")
- return self.__class__(self._creator,
- pool_size=self.size,
- recycle=self._recycle,
- echo=self.echo,
- logging_name=self._orig_logging_name,
- use_threadlocal=self._use_threadlocal,
- reset_on_return=self._reset_on_return,
- _dispatch=self.dispatch,
- dialect=self._dialect)
+ return self.__class__(
+ self._creator,
+ pool_size=self.size,
+ recycle=self._recycle,
+ echo=self.echo,
+ logging_name=self._orig_logging_name,
+ use_threadlocal=self._use_threadlocal,
+ reset_on_return=self._reset_on_return,
+ _dispatch=self.dispatch,
+ dialect=self._dialect,
+ )
def dispose(self):
"""Dispose of this pool."""
@@ -303,8 +325,10 @@ class SingletonThreadPool(Pool):
c.close()
def status(self):
- return "SingletonThreadPool id:%d size: %d" % \
- (id(self), len(self._all_conns))
+ return "SingletonThreadPool id:%d size: %d" % (
+ id(self),
+ len(self._all_conns),
+ )
def _do_return_conn(self, conn):
pass
@@ -347,20 +371,22 @@ class StaticPool(Pool):
return "StaticPool"
def dispose(self):
- if '_conn' in self.__dict__:
+ if "_conn" in self.__dict__:
self._conn.close()
self._conn = None
def recreate(self):
self.logger.info("Pool recreating")
- return self.__class__(creator=self._creator,
- recycle=self._recycle,
- use_threadlocal=self._use_threadlocal,
- reset_on_return=self._reset_on_return,
- echo=self.echo,
- logging_name=self._orig_logging_name,
- _dispatch=self.dispatch,
- dialect=self._dialect)
+ return self.__class__(
+ creator=self._creator,
+ recycle=self._recycle,
+ use_threadlocal=self._use_threadlocal,
+ reset_on_return=self._reset_on_return,
+ echo=self.echo,
+ logging_name=self._orig_logging_name,
+ _dispatch=self.dispatch,
+ dialect=self._dialect,
+ )
def _create_connection(self):
return self._conn
@@ -391,7 +417,7 @@ class AssertionPool(Pool):
def __init__(self, *args, **kw):
self._conn = None
self._checked_out = False
- self._store_traceback = kw.pop('store_traceback', True)
+ self._store_traceback = kw.pop("store_traceback", True)
self._checkout_traceback = None
Pool.__init__(self, *args, **kw)
@@ -411,18 +437,22 @@ class AssertionPool(Pool):
def recreate(self):
self.logger.info("Pool recreating")
- return self.__class__(self._creator, echo=self.echo,
- logging_name=self._orig_logging_name,
- _dispatch=self.dispatch,
- dialect=self._dialect)
+ return self.__class__(
+ self._creator,
+ echo=self.echo,
+ logging_name=self._orig_logging_name,
+ _dispatch=self.dispatch,
+ dialect=self._dialect,
+ )
def _do_get(self):
if self._checked_out:
if self._checkout_traceback:
- suffix = ' at:\n%s' % ''.join(
- chop_traceback(self._checkout_traceback))
+ suffix = " at:\n%s" % "".join(
+ chop_traceback(self._checkout_traceback)
+ )
else:
- suffix = ''
+ suffix = ""
raise AssertionError("connection is already checked out" + suffix)
if not self._conn:
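A minimal end-to-end sketch of the `QueuePool` signature reformatted above, assuming a stdlib SQLite in-memory connection as the creator:

    import sqlite3
    from sqlalchemy.pool import QueuePool

    pool = QueuePool(
        lambda: sqlite3.connect(":memory:"),
        pool_size=5,
        max_overflow=10,
        timeout=30,
    )
    conn = pool.connect()  # checks out a _ConnectionFairy
    conn.close()           # returns the connection to the pool
    print(pool.status())   # rendered by the status() string built above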
diff --git a/lib/sqlalchemy/processors.py b/lib/sqlalchemy/processors.py
index 860a55b8f..46d5dcbc6 100644
--- a/lib/sqlalchemy/processors.py
+++ b/lib/sqlalchemy/processors.py
@@ -32,20 +32,30 @@ def str_to_datetime_processor_factory(regexp, type_):
try:
m = rmatch(value)
except TypeError:
- raise ValueError("Couldn't parse %s string '%r' "
- "- value is not a string." %
- (type_.__name__, value))
+ raise ValueError(
+ "Couldn't parse %s string '%r' "
+ "- value is not a string." % (type_.__name__, value)
+ )
if m is None:
- raise ValueError("Couldn't parse %s string: "
- "'%s'" % (type_.__name__, value))
+ raise ValueError(
+                "Couldn't parse %s string: '%s'" % (type_.__name__, value)
+ )
if has_named_groups:
groups = m.groupdict(0)
- return type_(**dict(list(zip(
- iter(groups.keys()),
- list(map(int, iter(groups.values())))
- ))))
+ return type_(
+ **dict(
+ list(
+ zip(
+ iter(groups.keys()),
+ list(map(int, iter(groups.values()))),
+ )
+ )
+ )
+ )
else:
return type_(*list(map(int, m.groups(0))))
+
return process
@@ -61,6 +71,7 @@ def py_fallback():
# len part is safe: it is done that way in the normal
# 'xx'.decode(encoding) code path.
return decoder(value, errors)[0]
+
return process
def to_conditional_unicode_processor_factory(encoding, errors=None):
@@ -76,6 +87,7 @@ def py_fallback():
# len part is safe: it is done that way in the normal
# 'xx'.decode(encoding) code path.
return decoder(value, errors)[0]
+
return process
def to_decimal_processor_factory(target_class, scale):
@@ -86,6 +98,7 @@ def py_fallback():
return None
else:
return target_class(fstring % value)
+
return process
def to_float(value):
@@ -107,22 +120,30 @@ def py_fallback():
return bool(value)
DATETIME_RE = re.compile(
- r"(\d+)-(\d+)-(\d+) (\d+):(\d+):(\d+)(?:\.(\d+))?")
+ r"(\d+)-(\d+)-(\d+) (\d+):(\d+):(\d+)(?:\.(\d+))?"
+ )
TIME_RE = re.compile(r"(\d+):(\d+):(\d+)(?:\.(\d+))?")
DATE_RE = re.compile(r"(\d+)-(\d+)-(\d+)")
- str_to_datetime = str_to_datetime_processor_factory(DATETIME_RE,
- datetime.datetime)
+ str_to_datetime = str_to_datetime_processor_factory(
+ DATETIME_RE, datetime.datetime
+ )
str_to_time = str_to_datetime_processor_factory(TIME_RE, datetime.time)
str_to_date = str_to_datetime_processor_factory(DATE_RE, datetime.date)
return locals()
+
try:
- from sqlalchemy.cprocessors import UnicodeResultProcessor, \
- DecimalResultProcessor, \
- to_float, to_str, int_to_boolean, \
- str_to_datetime, str_to_time, \
- str_to_date
+ from sqlalchemy.cprocessors import (
+ UnicodeResultProcessor,
+ DecimalResultProcessor,
+ to_float,
+ to_str,
+ int_to_boolean,
+ str_to_datetime,
+ str_to_time,
+ str_to_date,
+ )
def to_unicode_processor_factory(encoding, errors=None):
if errors is not None:
@@ -144,5 +165,6 @@ try:
# return Decimal('5'). These are equivalent of course.
return DecimalResultProcessor(target_class, "%%.%df" % scale).process
+
except ImportError:
globals().update(py_fallback())
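The factory above pairs a regexp with a target type to build a string converter; a hedged sketch using the same `DATETIME_RE` pattern defined in `py_fallback()`:

    import datetime
    import re

    from sqlalchemy.processors import str_to_datetime_processor_factory

    DATETIME_RE = re.compile(
        r"(\d+)-(\d+)-(\d+) (\d+):(\d+):(\d+)(?:\.(\d+))?"
    )
    process = str_to_datetime_processor_factory(DATETIME_RE, datetime.datetime)

    process("2019-03-05 10:20:30.123")
    # -> datetime.datetime(2019, 3, 5, 10, 20, 30, 123)
    process(42)  # raises ValueError: "... value is not a string."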
diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py
index aa7b4f008..598d499dc 100644
--- a/lib/sqlalchemy/schema.py
+++ b/lib/sqlalchemy/schema.py
@@ -9,9 +9,7 @@
"""
-from .sql.base import (
- SchemaVisitor
- )
+from .sql.base import SchemaVisitor
from .sql.schema import (
@@ -36,8 +34,8 @@ from .sql.schema import (
UniqueConstraint,
_get_table_key,
ColumnCollectionConstraint,
- ColumnCollectionMixin
- )
+ ColumnCollectionMixin,
+)
from .sql.naming import conv
diff --git a/lib/sqlalchemy/sql/__init__.py b/lib/sqlalchemy/sql/__init__.py
index aa811388b..87e2fb6c3 100644
--- a/lib/sqlalchemy/sql/__init__.py
+++ b/lib/sqlalchemy/sql/__init__.py
@@ -72,7 +72,7 @@ from .expression import (
union,
union_all,
update,
- within_group
+ within_group,
)
from .visitors import ClauseVisitor
@@ -84,12 +84,16 @@ def __go(lcls):
import inspect as _inspect
- __all__ = sorted(name for name, obj in lcls.items()
- if not (name.startswith('_') or _inspect.ismodule(obj)))
+ __all__ = sorted(
+ name
+ for name, obj in lcls.items()
+ if not (name.startswith("_") or _inspect.ismodule(obj))
+ )
from .annotation import _prepare_annotations, Annotated
from .elements import AnnotatedColumnElement, ClauseList
from .selectable import AnnotatedFromClause
+
_prepare_annotations(ColumnElement, AnnotatedColumnElement)
_prepare_annotations(FromClause, AnnotatedFromClause)
_prepare_annotations(ClauseList, Annotated)
@@ -98,4 +102,5 @@ def __go(lcls):
from . import naming
+
__go(locals())
diff --git a/lib/sqlalchemy/sql/annotation.py b/lib/sqlalchemy/sql/annotation.py
index c1d484d95..64cfa630e 100644
--- a/lib/sqlalchemy/sql/annotation.py
+++ b/lib/sqlalchemy/sql/annotation.py
@@ -76,8 +76,7 @@ class Annotated(object):
return self._with_annotations(_values)
def _compiler_dispatch(self, visitor, **kw):
- return self.__element.__class__._compiler_dispatch(
- self, visitor, **kw)
+ return self.__element.__class__._compiler_dispatch(self, visitor, **kw)
@property
def _constructor(self):
@@ -120,10 +119,13 @@ def _deep_annotate(element, annotations, exclude=None):
Elements within the exclude collection will be cloned but not annotated.
"""
+
def clone(elem):
- if exclude and \
- hasattr(elem, 'proxy_set') and \
- elem.proxy_set.intersection(exclude):
+ if (
+ exclude
+ and hasattr(elem, "proxy_set")
+ and elem.proxy_set.intersection(exclude)
+ ):
newelem = elem._clone()
elif annotations != elem._annotations:
newelem = elem._annotate(annotations)
@@ -191,8 +193,8 @@ def _new_annotation_type(cls, base_cls):
break
annotated_classes[cls] = anno_cls = type(
- "Annotated%s" % cls.__name__,
- (base_cls, cls), {})
+ "Annotated%s" % cls.__name__, (base_cls, cls), {}
+ )
globals()["Annotated%s" % cls.__name__] = anno_cls
return anno_cls
diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py
index 6b9b55753..45db215fe 100644
--- a/lib/sqlalchemy/sql/base.py
+++ b/lib/sqlalchemy/sql/base.py
@@ -15,8 +15,8 @@ import itertools
from .visitors import ClauseVisitor
import re
-PARSE_AUTOCOMMIT = util.symbol('PARSE_AUTOCOMMIT')
-NO_ARG = util.symbol('NO_ARG')
+PARSE_AUTOCOMMIT = util.symbol("PARSE_AUTOCOMMIT")
+NO_ARG = util.symbol("NO_ARG")
class Immutable(object):
@@ -77,7 +77,8 @@ class _DialectArgView(util.collections_abc.MutableMapping):
dialect, value_key = self._key(key)
except KeyError:
raise exc.ArgumentError(
- "Keys must be of the form <dialectname>_<argname>")
+ "Keys must be of the form <dialectname>_<argname>"
+ )
else:
self.obj.dialect_options[dialect][value_key] = value
@@ -86,15 +87,18 @@ class _DialectArgView(util.collections_abc.MutableMapping):
del self.obj.dialect_options[dialect][value_key]
def __len__(self):
- return sum(len(args._non_defaults) for args in
- self.obj.dialect_options.values())
+ return sum(
+ len(args._non_defaults)
+ for args in self.obj.dialect_options.values()
+ )
def __iter__(self):
return (
util.safe_kwarg("%s_%s" % (dialect_name, value_name))
for dialect_name in self.obj.dialect_options
- for value_name in
- self.obj.dialect_options[dialect_name]._non_defaults
+ for value_name in self.obj.dialect_options[
+ dialect_name
+ ]._non_defaults
)
@@ -187,8 +191,8 @@ class DialectKWArgs(object):
if construct_arg_dictionary is None:
raise exc.ArgumentError(
"Dialect '%s' does have keyword-argument "
- "validation and defaults enabled configured" %
- dialect_name)
+ "validation and defaults enabled configured" % dialect_name
+ )
if cls not in construct_arg_dictionary:
construct_arg_dictionary[cls] = {}
construct_arg_dictionary[cls][argument_name] = default
@@ -230,6 +234,7 @@ class DialectKWArgs(object):
if dialect_cls.construct_arguments is None:
return None
return dict(dialect_cls.construct_arguments)
+
_kw_registry = util.PopulateDict(_kw_reg_for_dialect)
def _kw_reg_for_dialect_cls(self, dialect_name):
@@ -274,11 +279,12 @@ class DialectKWArgs(object):
return
for k in kwargs:
- m = re.match('^(.+?)_(.+)$', k)
+ m = re.match("^(.+?)_(.+)$", k)
if not m:
raise TypeError(
"Additional arguments should be "
- "named <dialectname>_<argument>, got '%s'" % k)
+ "named <dialectname>_<argument>, got '%s'" % k
+ )
dialect_name, arg_name = m.group(1, 2)
try:
@@ -286,20 +292,22 @@ class DialectKWArgs(object):
except exc.NoSuchModuleError:
util.warn(
"Can't validate argument %r; can't "
- "locate any SQLAlchemy dialect named %r" %
- (k, dialect_name))
+ "locate any SQLAlchemy dialect named %r"
+ % (k, dialect_name)
+ )
self.dialect_options[dialect_name] = d = _DialectArgDict()
d._defaults.update({"*": None})
d._non_defaults[arg_name] = kwargs[k]
else:
- if "*" not in construct_arg_dictionary and \
- arg_name not in construct_arg_dictionary:
+ if (
+ "*" not in construct_arg_dictionary
+ and arg_name not in construct_arg_dictionary
+ ):
raise exc.ArgumentError(
"Argument %r is not accepted by "
- "dialect %r on behalf of %r" % (
- k,
- dialect_name, self.__class__
- ))
+ "dialect %r on behalf of %r"
+ % (k, dialect_name, self.__class__)
+ )
else:
construct_arg_dictionary[arg_name] = kwargs[k]
@@ -359,14 +367,14 @@ class Executable(Generative):
:meth:`.Query.execution_options()`
"""
- if 'isolation_level' in kw:
+ if "isolation_level" in kw:
raise exc.ArgumentError(
"'isolation_level' execution option may only be specified "
"on Connection.execution_options(), or "
"per-engine using the isolation_level "
"argument to create_engine()."
)
- if 'compiled_cache' in kw:
+ if "compiled_cache" in kw:
raise exc.ArgumentError(
"'compiled_cache' execution option may only be specified "
"on Connection.execution_options(), not per statement."
@@ -377,10 +385,12 @@ class Executable(Generative):
"""Compile and execute this :class:`.Executable`."""
e = self.bind
if e is None:
- label = getattr(self, 'description', self.__class__.__name__)
- msg = ('This %s is not directly bound to a Connection or Engine. '
- 'Use the .execute() method of a Connection or Engine '
- 'to execute this construct.' % label)
+ label = getattr(self, "description", self.__class__.__name__)
+ msg = (
+ "This %s is not directly bound to a Connection or Engine. "
+ "Use the .execute() method of a Connection or Engine "
+ "to execute this construct." % label
+ )
raise exc.UnboundExecutionError(msg)
return e._execute_clauseelement(self, multiparams, params)
@@ -434,7 +444,7 @@ class SchemaEventTarget(object):
class SchemaVisitor(ClauseVisitor):
"""Define the visiting for ``SchemaItem`` objects."""
- __traverse_options__ = {'schema_visitor': True}
+ __traverse_options__ = {"schema_visitor": True}
class ColumnCollection(util.OrderedProperties):
@@ -446,11 +456,11 @@ class ColumnCollection(util.OrderedProperties):
"""
- __slots__ = '_all_columns'
+ __slots__ = "_all_columns"
def __init__(self, *columns):
super(ColumnCollection, self).__init__()
- object.__setattr__(self, '_all_columns', [])
+ object.__setattr__(self, "_all_columns", [])
for c in columns:
self.add(c)
@@ -485,8 +495,9 @@ class ColumnCollection(util.OrderedProperties):
self._data[column.key] = column
if remove_col is not None:
- self._all_columns[:] = [column if c is remove_col
- else c for c in self._all_columns]
+ self._all_columns[:] = [
+ column if c is remove_col else c for c in self._all_columns
+ ]
else:
self._all_columns.append(column)
@@ -499,7 +510,8 @@ class ColumnCollection(util.OrderedProperties):
"""
if not column.key:
raise exc.ArgumentError(
- "Can't add unnamed column to column collection")
+ "Can't add unnamed column to column collection"
+ )
self[column.key] = column
def __delitem__(self, key):
@@ -521,10 +533,12 @@ class ColumnCollection(util.OrderedProperties):
return
if not existing.shares_lineage(value):
- util.warn('Column %r on table %r being replaced by '
- '%r, which has the same key. Consider '
- 'use_labels for select() statements.' %
- (key, getattr(existing, 'table', None), value))
+ util.warn(
+ "Column %r on table %r being replaced by "
+ "%r, which has the same key. Consider "
+ "use_labels for select() statements."
+ % (key, getattr(existing, "table", None), value)
+ )
# pop out memoized proxy_set as this
# operation may very well be occurring
@@ -540,13 +554,15 @@ class ColumnCollection(util.OrderedProperties):
def remove(self, column):
del self._data[column.key]
self._all_columns[:] = [
- c for c in self._all_columns if c is not column]
+ c for c in self._all_columns if c is not column
+ ]
def update(self, iter):
cols = list(iter)
all_col_set = set(self._all_columns)
self._all_columns.extend(
- c for label, c in cols if c not in all_col_set)
+ c for label, c in cols if c not in all_col_set
+ )
self._data.update((label, c) for label, c in cols)
def extend(self, iter):
@@ -572,12 +588,11 @@ class ColumnCollection(util.OrderedProperties):
return util.OrderedProperties.__contains__(self, other)
def __getstate__(self):
- return {'_data': self._data,
- '_all_columns': self._all_columns}
+ return {"_data": self._data, "_all_columns": self._all_columns}
def __setstate__(self, state):
- object.__setattr__(self, '_data', state['_data'])
- object.__setattr__(self, '_all_columns', state['_all_columns'])
+ object.__setattr__(self, "_data", state["_data"])
+ object.__setattr__(self, "_all_columns", state["_all_columns"])
def contains_column(self, col):
return col in set(self._all_columns)
@@ -589,7 +604,7 @@ class ColumnCollection(util.OrderedProperties):
class ImmutableColumnCollection(util.ImmutableProperties, ColumnCollection):
def __init__(self, data, all_columns):
util.ImmutableProperties.__init__(self, data)
- object.__setattr__(self, '_all_columns', all_columns)
+ object.__setattr__(self, "_all_columns", all_columns)
extend = remove = util.ImmutableProperties._immutable
@@ -622,15 +637,18 @@ def _bind_or_error(schemaitem, msg=None):
bind = schemaitem.bind
if not bind:
name = schemaitem.__class__.__name__
- label = getattr(schemaitem, 'fullname',
- getattr(schemaitem, 'name', None))
+ label = getattr(
+ schemaitem, "fullname", getattr(schemaitem, "name", None)
+ )
if label:
- item = '%s object %r' % (name, label)
+ item = "%s object %r" % (name, label)
else:
- item = '%s object' % name
+ item = "%s object" % name
if msg is None:
- msg = "%s is not bound to an Engine or Connection. "\
- "Execution can not proceed without a database to execute "\
+ msg = (
+ "%s is not bound to an Engine or Connection. "
+ "Execution can not proceed without a database to execute "
"against." % item
+ )
raise exc.UnboundExecutionError(msg)
return bind
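The `<dialectname>_<argname>` convention validated above is what carries dialect-specific keyword arguments; a hedged sketch (`postgresql_using` is a real argument accepted by `Index` on the PostgreSQL dialect):

    from sqlalchemy import Column, Index, Integer, MetaData, Table

    t = Table("t", MetaData(), Column("data", Integer))
    ix = Index("ix_data", t.c.data, postgresql_using="gin")

    ix.dialect_options["postgresql"]["using"]  # -> "gin"
    ix.dialect_kwargs["postgresql_using"]      # same value, flat view
    # an argument naming an unknown dialect only warns; an unknown
    # argument for a known dialect raises ArgumentError.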
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 80ed707ed..f641d0a84 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -25,133 +25,218 @@ To generate user-defined SQL strings, see
import contextlib
import re
-from . import schema, sqltypes, operators, functions, visitors, \
- elements, selectable, crud
+from . import (
+ schema,
+ sqltypes,
+ operators,
+ functions,
+ visitors,
+ elements,
+ selectable,
+ crud,
+)
from .. import util, exc
import itertools
-RESERVED_WORDS = set([
- 'all', 'analyse', 'analyze', 'and', 'any', 'array',
- 'as', 'asc', 'asymmetric', 'authorization', 'between',
- 'binary', 'both', 'case', 'cast', 'check', 'collate',
- 'column', 'constraint', 'create', 'cross', 'current_date',
- 'current_role', 'current_time', 'current_timestamp',
- 'current_user', 'default', 'deferrable', 'desc',
- 'distinct', 'do', 'else', 'end', 'except', 'false',
- 'for', 'foreign', 'freeze', 'from', 'full', 'grant',
- 'group', 'having', 'ilike', 'in', 'initially', 'inner',
- 'intersect', 'into', 'is', 'isnull', 'join', 'leading',
- 'left', 'like', 'limit', 'localtime', 'localtimestamp',
- 'natural', 'new', 'not', 'notnull', 'null', 'off', 'offset',
- 'old', 'on', 'only', 'or', 'order', 'outer', 'overlaps',
- 'placing', 'primary', 'references', 'right', 'select',
- 'session_user', 'set', 'similar', 'some', 'symmetric', 'table',
- 'then', 'to', 'trailing', 'true', 'union', 'unique', 'user',
- 'using', 'verbose', 'when', 'where'])
-
-LEGAL_CHARACTERS = re.compile(r'^[A-Z0-9_$]+$', re.I)
-ILLEGAL_INITIAL_CHARACTERS = {str(x) for x in range(0, 10)}.union(['$'])
-
-BIND_PARAMS = re.compile(r'(?<![:\w\$\x5c]):([\w\$]+)(?![:\w\$])', re.UNICODE)
-BIND_PARAMS_ESC = re.compile(r'\x5c(:[\w\$]*)(?![:\w\$])', re.UNICODE)
+RESERVED_WORDS = set(
+ [
+ "all",
+ "analyse",
+ "analyze",
+ "and",
+ "any",
+ "array",
+ "as",
+ "asc",
+ "asymmetric",
+ "authorization",
+ "between",
+ "binary",
+ "both",
+ "case",
+ "cast",
+ "check",
+ "collate",
+ "column",
+ "constraint",
+ "create",
+ "cross",
+ "current_date",
+ "current_role",
+ "current_time",
+ "current_timestamp",
+ "current_user",
+ "default",
+ "deferrable",
+ "desc",
+ "distinct",
+ "do",
+ "else",
+ "end",
+ "except",
+ "false",
+ "for",
+ "foreign",
+ "freeze",
+ "from",
+ "full",
+ "grant",
+ "group",
+ "having",
+ "ilike",
+ "in",
+ "initially",
+ "inner",
+ "intersect",
+ "into",
+ "is",
+ "isnull",
+ "join",
+ "leading",
+ "left",
+ "like",
+ "limit",
+ "localtime",
+ "localtimestamp",
+ "natural",
+ "new",
+ "not",
+ "notnull",
+ "null",
+ "off",
+ "offset",
+ "old",
+ "on",
+ "only",
+ "or",
+ "order",
+ "outer",
+ "overlaps",
+ "placing",
+ "primary",
+ "references",
+ "right",
+ "select",
+ "session_user",
+ "set",
+ "similar",
+ "some",
+ "symmetric",
+ "table",
+ "then",
+ "to",
+ "trailing",
+ "true",
+ "union",
+ "unique",
+ "user",
+ "using",
+ "verbose",
+ "when",
+ "where",
+ ]
+)
+
+LEGAL_CHARACTERS = re.compile(r"^[A-Z0-9_$]+$", re.I)
+ILLEGAL_INITIAL_CHARACTERS = {str(x) for x in range(0, 10)}.union(["$"])
+
+BIND_PARAMS = re.compile(r"(?<![:\w\$\x5c]):([\w\$]+)(?![:\w\$])", re.UNICODE)
+BIND_PARAMS_ESC = re.compile(r"\x5c(:[\w\$]*)(?![:\w\$])", re.UNICODE)
BIND_TEMPLATES = {
- 'pyformat': "%%(%(name)s)s",
- 'qmark': "?",
- 'format': "%%s",
- 'numeric': ":[_POSITION]",
- 'named': ":%(name)s"
+ "pyformat": "%%(%(name)s)s",
+ "qmark": "?",
+ "format": "%%s",
+ "numeric": ":[_POSITION]",
+ "named": ":%(name)s",
}
OPERATORS = {
# binary
- operators.and_: ' AND ',
- operators.or_: ' OR ',
- operators.add: ' + ',
- operators.mul: ' * ',
- operators.sub: ' - ',
- operators.div: ' / ',
- operators.mod: ' % ',
- operators.truediv: ' / ',
- operators.neg: '-',
- operators.lt: ' < ',
- operators.le: ' <= ',
- operators.ne: ' != ',
- operators.gt: ' > ',
- operators.ge: ' >= ',
- operators.eq: ' = ',
- operators.is_distinct_from: ' IS DISTINCT FROM ',
- operators.isnot_distinct_from: ' IS NOT DISTINCT FROM ',
- operators.concat_op: ' || ',
- operators.match_op: ' MATCH ',
- operators.notmatch_op: ' NOT MATCH ',
- operators.in_op: ' IN ',
- operators.notin_op: ' NOT IN ',
- operators.comma_op: ', ',
- operators.from_: ' FROM ',
- operators.as_: ' AS ',
- operators.is_: ' IS ',
- operators.isnot: ' IS NOT ',
- operators.collate: ' COLLATE ',
-
+ operators.and_: " AND ",
+ operators.or_: " OR ",
+ operators.add: " + ",
+ operators.mul: " * ",
+ operators.sub: " - ",
+ operators.div: " / ",
+ operators.mod: " % ",
+ operators.truediv: " / ",
+ operators.neg: "-",
+ operators.lt: " < ",
+ operators.le: " <= ",
+ operators.ne: " != ",
+ operators.gt: " > ",
+ operators.ge: " >= ",
+ operators.eq: " = ",
+ operators.is_distinct_from: " IS DISTINCT FROM ",
+ operators.isnot_distinct_from: " IS NOT DISTINCT FROM ",
+ operators.concat_op: " || ",
+ operators.match_op: " MATCH ",
+ operators.notmatch_op: " NOT MATCH ",
+ operators.in_op: " IN ",
+ operators.notin_op: " NOT IN ",
+ operators.comma_op: ", ",
+ operators.from_: " FROM ",
+ operators.as_: " AS ",
+ operators.is_: " IS ",
+ operators.isnot: " IS NOT ",
+ operators.collate: " COLLATE ",
# unary
- operators.exists: 'EXISTS ',
- operators.distinct_op: 'DISTINCT ',
- operators.inv: 'NOT ',
- operators.any_op: 'ANY ',
- operators.all_op: 'ALL ',
-
+ operators.exists: "EXISTS ",
+ operators.distinct_op: "DISTINCT ",
+ operators.inv: "NOT ",
+ operators.any_op: "ANY ",
+ operators.all_op: "ALL ",
# modifiers
- operators.desc_op: ' DESC',
- operators.asc_op: ' ASC',
- operators.nullsfirst_op: ' NULLS FIRST',
- operators.nullslast_op: ' NULLS LAST',
-
+ operators.desc_op: " DESC",
+ operators.asc_op: " ASC",
+ operators.nullsfirst_op: " NULLS FIRST",
+ operators.nullslast_op: " NULLS LAST",
}
FUNCTIONS = {
- functions.coalesce: 'coalesce',
- functions.current_date: 'CURRENT_DATE',
- functions.current_time: 'CURRENT_TIME',
- functions.current_timestamp: 'CURRENT_TIMESTAMP',
- functions.current_user: 'CURRENT_USER',
- functions.localtime: 'LOCALTIME',
- functions.localtimestamp: 'LOCALTIMESTAMP',
- functions.random: 'random',
- functions.sysdate: 'sysdate',
- functions.session_user: 'SESSION_USER',
- functions.user: 'USER',
- functions.cube: 'CUBE',
- functions.rollup: 'ROLLUP',
- functions.grouping_sets: 'GROUPING SETS',
+ functions.coalesce: "coalesce",
+ functions.current_date: "CURRENT_DATE",
+ functions.current_time: "CURRENT_TIME",
+ functions.current_timestamp: "CURRENT_TIMESTAMP",
+ functions.current_user: "CURRENT_USER",
+ functions.localtime: "LOCALTIME",
+ functions.localtimestamp: "LOCALTIMESTAMP",
+ functions.random: "random",
+ functions.sysdate: "sysdate",
+ functions.session_user: "SESSION_USER",
+ functions.user: "USER",
+ functions.cube: "CUBE",
+ functions.rollup: "ROLLUP",
+ functions.grouping_sets: "GROUPING SETS",
}
EXTRACT_MAP = {
- 'month': 'month',
- 'day': 'day',
- 'year': 'year',
- 'second': 'second',
- 'hour': 'hour',
- 'doy': 'doy',
- 'minute': 'minute',
- 'quarter': 'quarter',
- 'dow': 'dow',
- 'week': 'week',
- 'epoch': 'epoch',
- 'milliseconds': 'milliseconds',
- 'microseconds': 'microseconds',
- 'timezone_hour': 'timezone_hour',
- 'timezone_minute': 'timezone_minute'
+ "month": "month",
+ "day": "day",
+ "year": "year",
+ "second": "second",
+ "hour": "hour",
+ "doy": "doy",
+ "minute": "minute",
+ "quarter": "quarter",
+ "dow": "dow",
+ "week": "week",
+ "epoch": "epoch",
+ "milliseconds": "milliseconds",
+ "microseconds": "microseconds",
+ "timezone_hour": "timezone_hour",
+ "timezone_minute": "timezone_minute",
}
COMPOUND_KEYWORDS = {
- selectable.CompoundSelect.UNION: 'UNION',
- selectable.CompoundSelect.UNION_ALL: 'UNION ALL',
- selectable.CompoundSelect.EXCEPT: 'EXCEPT',
- selectable.CompoundSelect.EXCEPT_ALL: 'EXCEPT ALL',
- selectable.CompoundSelect.INTERSECT: 'INTERSECT',
- selectable.CompoundSelect.INTERSECT_ALL: 'INTERSECT ALL'
+ selectable.CompoundSelect.UNION: "UNION",
+ selectable.CompoundSelect.UNION_ALL: "UNION ALL",
+ selectable.CompoundSelect.EXCEPT: "EXCEPT",
+ selectable.CompoundSelect.EXCEPT_ALL: "EXCEPT ALL",
+ selectable.CompoundSelect.INTERSECT: "INTERSECT",
+ selectable.CompoundSelect.INTERSECT_ALL: "INTERSECT ALL",
}
@@ -177,9 +262,14 @@ class Compiled(object):
sub-elements of the statement can modify these.
"""
- def __init__(self, dialect, statement, bind=None,
- schema_translate_map=None,
- compile_kwargs=util.immutabledict()):
+ def __init__(
+ self,
+ dialect,
+ statement,
+ bind=None,
+ schema_translate_map=None,
+ compile_kwargs=util.immutabledict(),
+ ):
"""Construct a new :class:`.Compiled` object.
:param dialect: :class:`.Dialect` to compile against.
@@ -209,7 +299,8 @@ class Compiled(object):
self.preparer = self.dialect.identifier_preparer
if schema_translate_map:
self.preparer = self.preparer._with_schema_translate(
- schema_translate_map)
+ schema_translate_map
+ )
if statement is not None:
self.statement = statement
@@ -218,8 +309,10 @@ class Compiled(object):
self.execution_options = statement._execution_options
self.string = self.process(self.statement, **compile_kwargs)
- @util.deprecated("0.7", ":class:`.Compiled` objects now compile "
- "within the constructor.")
+ @util.deprecated(
+ "0.7",
+        ":class:`.Compiled` objects now compile within the constructor.",
+ )
def compile(self):
"""Produce the internal string representation of this element.
"""
@@ -247,7 +340,7 @@ class Compiled(object):
def __str__(self):
"""Return the string text of the generated SQL or DDL."""
- return self.string or ''
+ return self.string or ""
def construct_params(self, params=None):
"""Return the bind params for this compiled object.
@@ -271,7 +364,9 @@ class Compiled(object):
if e is None:
raise exc.UnboundExecutionError(
"This Compiled object is not bound to any Engine "
- "or Connection.", code="2afi")
+ "or Connection.",
+ code="2afi",
+ )
return e._execute_compiled(self, multiparams, params)
def scalar(self, *multiparams, **params):
@@ -284,7 +379,7 @@ class Compiled(object):
class TypeCompiler(util.with_metaclass(util.EnsureKWArgType, object)):
"""Produces DDL specification for TypeEngine objects."""
- ensure_kwarg = r'visit_\w+'
+ ensure_kwarg = r"visit_\w+"
def __init__(self, dialect):
self.dialect = dialect
@@ -297,8 +392,8 @@ class _CompileLabel(visitors.Visitable):
"""lightweight label object which acts as an expression.Label."""
- __visit_name__ = 'label'
- __slots__ = 'element', 'name'
+ __visit_name__ = "label"
+ __slots__ = "element", "name"
def __init__(self, col, name, alt_names=()):
self.element = col
@@ -390,8 +485,9 @@ class SQLCompiler(Compiled):
insert_prefetch = update_prefetch = ()
- def __init__(self, dialect, statement, column_keys=None,
- inline=False, **kwargs):
+ def __init__(
+ self, dialect, statement, column_keys=None, inline=False, **kwargs
+ ):
"""Construct a new :class:`.SQLCompiler` object.
:param dialect: :class:`.Dialect` to be used
@@ -412,7 +508,7 @@ class SQLCompiler(Compiled):
# compile INSERT/UPDATE defaults/sequences inlined (no pre-
# execute)
- self.inline = inline or getattr(statement, 'inline', False)
+ self.inline = inline or getattr(statement, "inline", False)
# a dictionary of bind parameter keys to BindParameter
# instances.
@@ -440,8 +536,9 @@ class SQLCompiler(Compiled):
self.ctes = None
- self.label_length = dialect.label_length \
- or dialect.max_identifier_length
+ self.label_length = (
+ dialect.label_length or dialect.max_identifier_length
+ )
# a map which tracks "anonymous" identifiers that are created on
# the fly here
@@ -453,7 +550,7 @@ class SQLCompiler(Compiled):
Compiled.__init__(self, dialect, statement, **kwargs)
if (
- self.isinsert or self.isupdate or self.isdelete
+ self.isinsert or self.isupdate or self.isdelete
) and statement._returning:
self.returning = statement._returning
@@ -482,37 +579,43 @@ class SQLCompiler(Compiled):
def _nested_result(self):
"""special API to support the use case of 'nested result sets'"""
result_columns, ordered_columns = (
- self._result_columns, self._ordered_columns)
+ self._result_columns,
+ self._ordered_columns,
+ )
self._result_columns, self._ordered_columns = [], False
try:
if self.stack:
entry = self.stack[-1]
- entry['need_result_map_for_nested'] = True
+ entry["need_result_map_for_nested"] = True
else:
entry = None
yield self._result_columns, self._ordered_columns
finally:
if entry:
- entry.pop('need_result_map_for_nested')
+ entry.pop("need_result_map_for_nested")
self._result_columns, self._ordered_columns = (
- result_columns, ordered_columns)
+ result_columns,
+ ordered_columns,
+ )
def _apply_numbered_params(self):
poscount = itertools.count(1)
self.string = re.sub(
- r'\[_POSITION\]',
- lambda m: str(util.next(poscount)),
- self.string)
+ r"\[_POSITION\]", lambda m: str(util.next(poscount)), self.string
+ )
@util.memoized_property
def _bind_processors(self):
return dict(
- (key, value) for key, value in
- ((self.bind_names[bindparam],
- bindparam.type._cached_bind_processor(self.dialect)
- )
- for bindparam in self.bind_names)
+ (key, value)
+ for key, value in (
+ (
+ self.bind_names[bindparam],
+ bindparam.type._cached_bind_processor(self.dialect),
+ )
+ for bindparam in self.bind_names
+ )
if value is not None
)
@@ -539,12 +642,16 @@ class SQLCompiler(Compiled):
if _group_number:
raise exc.InvalidRequestError(
"A value is required for bind parameter %r, "
- "in parameter group %d" %
- (bindparam.key, _group_number), code="cd3x")
+ "in parameter group %d"
+ % (bindparam.key, _group_number),
+ code="cd3x",
+ )
else:
raise exc.InvalidRequestError(
"A value is required for bind parameter %r"
- % bindparam.key, code="cd3x")
+ % bindparam.key,
+ code="cd3x",
+ )
elif bindparam.callable:
pd[name] = bindparam.effective_value
@@ -558,12 +665,16 @@ class SQLCompiler(Compiled):
if _group_number:
raise exc.InvalidRequestError(
"A value is required for bind parameter %r, "
- "in parameter group %d" %
- (bindparam.key, _group_number), code="cd3x")
+ "in parameter group %d"
+ % (bindparam.key, _group_number),
+ code="cd3x",
+ )
else:
raise exc.InvalidRequestError(
"A value is required for bind parameter %r"
- % bindparam.key, code="cd3x")
+ % bindparam.key,
+ code="cd3x",
+ )
if bindparam.callable:
pd[self.bind_names[bindparam]] = bindparam.effective_value
@@ -595,9 +706,10 @@ class SQLCompiler(Compiled):
return "(" + grouping.element._compiler_dispatch(self, **kwargs) + ")"
def visit_label_reference(
- self, element, within_columns_clause=False, **kwargs):
+ self, element, within_columns_clause=False, **kwargs
+ ):
if self.stack and self.dialect.supports_simple_order_by_label:
- selectable = self.stack[-1]['selectable']
+ selectable = self.stack[-1]["selectable"]
with_cols, only_froms, only_cols = selectable._label_resolve_dict
if within_columns_clause:
@@ -611,25 +723,30 @@ class SQLCompiler(Compiled):
# to something else like a ColumnClause expression.
order_by_elem = element.element._order_by_label_element
- if order_by_elem is not None and order_by_elem.name in \
- resolve_dict and \
- order_by_elem.shares_lineage(
- resolve_dict[order_by_elem.name]):
- kwargs['render_label_as_label'] = \
- element.element._order_by_label_element
+ if (
+ order_by_elem is not None
+ and order_by_elem.name in resolve_dict
+ and order_by_elem.shares_lineage(
+ resolve_dict[order_by_elem.name]
+ )
+ ):
+ kwargs[
+ "render_label_as_label"
+ ] = element.element._order_by_label_element
return self.process(
- element.element, within_columns_clause=within_columns_clause,
- **kwargs)
+ element.element,
+ within_columns_clause=within_columns_clause,
+ **kwargs
+ )
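
As context for `visit_label_reference` and the `visit_textual_label_reference` hunk below: these paths let an ORDER BY name a labeled column directly on dialects with `supports_simple_order_by_label`. A minimal sketch using the 1.x `select()` API (table and label names are hypothetical):

    from sqlalchemy import Column, Integer, MetaData, Table, select

    t = Table("t", MetaData(), Column("x", Integer))
    stmt = select([(t.c.x + 1).label("x1")]).order_by("x1")
    # On label-friendly dialects this renders roughly:
    # SELECT t.x + :x_1 AS x1 FROM t ORDER BY x1
    print(stmt)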
def visit_textual_label_reference(
- self, element, within_columns_clause=False, **kwargs):
+ self, element, within_columns_clause=False, **kwargs
+ ):
if not self.stack:
# compiling the element outside of the context of a SELECT
- return self.process(
- element._text_clause
- )
+ return self.process(element._text_clause)
- selectable = self.stack[-1]['selectable']
+ selectable = self.stack[-1]["selectable"]
with_cols, only_froms, only_cols = selectable._label_resolve_dict
try:
if within_columns_clause:
@@ -640,26 +757,30 @@ class SQLCompiler(Compiled):
# treat it like text()
util.warn_limited(
"Can't resolve label reference %r; converting to text()",
- util.ellipses_string(element.element))
- return self.process(
- element._text_clause
+ util.ellipses_string(element.element),
)
+ return self.process(element._text_clause)
else:
- kwargs['render_label_as_label'] = col
+ kwargs["render_label_as_label"] = col
return self.process(
- col, within_columns_clause=within_columns_clause, **kwargs)
-
- def visit_label(self, label,
- add_to_result_map=None,
- within_label_clause=False,
- within_columns_clause=False,
- render_label_as_label=None,
- **kw):
+ col, within_columns_clause=within_columns_clause, **kwargs
+ )
+
+ def visit_label(
+ self,
+ label,
+ add_to_result_map=None,
+ within_label_clause=False,
+ within_columns_clause=False,
+ render_label_as_label=None,
+ **kw
+ ):
# only render labels within the columns clause
# or ORDER BY clause of a select. dialect-specific compilers
# can modify this behavior.
- render_label_with_as = (within_columns_clause and not
- within_label_clause)
+ render_label_with_as = (
+ within_columns_clause and not within_label_clause
+ )
render_label_only = render_label_as_label is label
if render_label_only or render_label_with_as:
@@ -673,27 +794,35 @@ class SQLCompiler(Compiled):
add_to_result_map(
labelname,
label.name,
- (label, labelname, ) + label._alt_names,
- label.type
+ (label, labelname) + label._alt_names,
+ label.type,
)
- return label.element._compiler_dispatch(
- self, within_columns_clause=True,
- within_label_clause=True, **kw) + \
- OPERATORS[operators.as_] + \
- self.preparer.format_label(label, labelname)
+ return (
+ label.element._compiler_dispatch(
+ self,
+ within_columns_clause=True,
+ within_label_clause=True,
+ **kw
+ )
+ + OPERATORS[operators.as_]
+ + self.preparer.format_label(label, labelname)
+ )
elif render_label_only:
return self.preparer.format_label(label, labelname)
else:
return label.element._compiler_dispatch(
- self, within_columns_clause=False, **kw)
+ self, within_columns_clause=False, **kw
+ )
def _fallback_column_name(self, column):
- raise exc.CompileError("Cannot compile Column object until "
- "its 'name' is assigned.")
+ raise exc.CompileError(
+ "Cannot compile Column object until " "its 'name' is assigned."
+ )
- def visit_column(self, column, add_to_result_map=None,
- include_table=True, **kwargs):
+ def visit_column(
+ self, column, add_to_result_map=None, include_table=True, **kwargs
+ ):
name = orig_name = column.name
if name is None:
name = self._fallback_column_name(column)
@@ -704,10 +833,7 @@ class SQLCompiler(Compiled):
if add_to_result_map is not None:
add_to_result_map(
- name,
- orig_name,
- (column, name, column.key),
- column.type
+ name, orig_name, (column, name, column.key), column.type
)
if is_literal:
@@ -721,17 +847,16 @@ class SQLCompiler(Compiled):
effective_schema = self.preparer.schema_for_object(table)
if effective_schema:
- schema_prefix = self.preparer.quote_schema(
- effective_schema) + '.'
+ schema_prefix = (
+ self.preparer.quote_schema(effective_schema) + "."
+ )
else:
- schema_prefix = ''
+ schema_prefix = ""
tablename = table.name
if isinstance(tablename, elements._truncated_label):
tablename = self._truncated_identifier("alias", tablename)
- return schema_prefix + \
- self.preparer.quote(tablename) + \
- "." + name
+ return schema_prefix + self.preparer.quote(tablename) + "." + name
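
The rewrapped tail of `visit_column` is the schema-qualified rendering path; a quick sketch of the behavior it implements (schema and table names are hypothetical):

    from sqlalchemy import Column, Integer, MetaData, Table, select

    t = Table("orders", MetaData(), Column("id", Integer), schema="sales")
    # include_table rendering qualifies the column with schema and table:
    # SELECT sales.orders.id FROM sales.orders
    print(select([t.c.id]))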
def visit_collation(self, element, **kw):
return self.preparer.format_collation(element.collation)
@@ -743,17 +868,17 @@ class SQLCompiler(Compiled):
return index.name
def visit_typeclause(self, typeclause, **kw):
- kw['type_expression'] = typeclause
+ kw["type_expression"] = typeclause
return self.dialect.type_compiler.process(typeclause.type, **kw)
def post_process_text(self, text):
if self.preparer._double_percents:
- text = text.replace('%', '%%')
+ text = text.replace("%", "%%")
return text
def escape_literal_column(self, text):
if self.preparer._double_percents:
- text = text.replace('%', '%%')
+ text = text.replace("%", "%%")
return text
def visit_textclause(self, textclause, **kw):
@@ -771,30 +896,36 @@ class SQLCompiler(Compiled):
return BIND_PARAMS_ESC.sub(
lambda m: m.group(1),
BIND_PARAMS.sub(
- do_bindparam,
- self.post_process_text(textclause.text))
+ do_bindparam, self.post_process_text(textclause.text)
+ ),
)
- def visit_text_as_from(self, taf,
- compound_index=None,
- asfrom=False,
- parens=True, **kw):
+ def visit_text_as_from(
+ self, taf, compound_index=None, asfrom=False, parens=True, **kw
+ ):
toplevel = not self.stack
entry = self._default_stack_entry if toplevel else self.stack[-1]
- populate_result_map = toplevel or \
- (
- compound_index == 0 and entry.get(
- 'need_result_map_for_compound', False)
- ) or entry.get('need_result_map_for_nested', False)
+ populate_result_map = (
+ toplevel
+ or (
+ compound_index == 0
+ and entry.get("need_result_map_for_compound", False)
+ )
+ or entry.get("need_result_map_for_nested", False)
+ )
if populate_result_map:
- self._ordered_columns = \
- self._textual_ordered_columns = taf.positional
+ self._ordered_columns = (
+ self._textual_ordered_columns
+ ) = taf.positional
for c in taf.column_args:
- self.process(c, within_columns_clause=True,
- add_to_result_map=self._add_to_result_map)
+ self.process(
+ c,
+ within_columns_clause=True,
+ add_to_result_map=self._add_to_result_map,
+ )
text = self.process(taf.element, **kw)
if asfrom and parens:
@@ -802,17 +933,17 @@ class SQLCompiler(Compiled):
return text
def visit_null(self, expr, **kw):
- return 'NULL'
+ return "NULL"
def visit_true(self, expr, **kw):
if self.dialect.supports_native_boolean:
- return 'true'
+ return "true"
else:
return "1"
def visit_false(self, expr, **kw):
if self.dialect.supports_native_boolean:
- return 'false'
+ return "false"
else:
return "0"
@@ -823,25 +954,29 @@ class SQLCompiler(Compiled):
else:
sep = OPERATORS[clauselist.operator]
return sep.join(
- s for s in
- (
- c._compiler_dispatch(self, **kw)
- for c in clauselist.clauses)
- if s)
+ s
+ for s in (
+ c._compiler_dispatch(self, **kw) for c in clauselist.clauses
+ )
+ if s
+ )
def visit_case(self, clause, **kwargs):
x = "CASE "
if clause.value is not None:
x += clause.value._compiler_dispatch(self, **kwargs) + " "
for cond, result in clause.whens:
- x += "WHEN " + cond._compiler_dispatch(
- self, **kwargs
- ) + " THEN " + result._compiler_dispatch(
- self, **kwargs) + " "
+ x += (
+ "WHEN "
+ + cond._compiler_dispatch(self, **kwargs)
+ + " THEN "
+ + result._compiler_dispatch(self, **kwargs)
+ + " "
+ )
if clause.else_ is not None:
- x += "ELSE " + clause.else_._compiler_dispatch(
- self, **kwargs
- ) + " "
+ x += (
+ "ELSE " + clause.else_._compiler_dispatch(self, **kwargs) + " "
+ )
x += "END"
return x
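
A minimal sketch of the construct `visit_case` renders (names are illustrative):

    from sqlalchemy import Column, Integer, MetaData, Table, case, select

    t = Table("t", MetaData(), Column("x", Integer))
    expr = case([(t.c.x > 5, "big")], else_="small")
    # Renders: CASE WHEN (t.x > :x_1) THEN :param_1 ELSE :param_2 END
    print(select([expr]))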
@@ -849,79 +984,84 @@ class SQLCompiler(Compiled):
return type_coerce.typed_expression._compiler_dispatch(self, **kw)
def visit_cast(self, cast, **kwargs):
- return "CAST(%s AS %s)" % \
- (cast.clause._compiler_dispatch(self, **kwargs),
- cast.typeclause._compiler_dispatch(self, **kwargs))
+ return "CAST(%s AS %s)" % (
+ cast.clause._compiler_dispatch(self, **kwargs),
+ cast.typeclause._compiler_dispatch(self, **kwargs),
+ )
def _format_frame_clause(self, range_, **kw):
- return '%s AND %s' % (
+ return "%s AND %s" % (
"UNBOUNDED PRECEDING"
if range_[0] is elements.RANGE_UNBOUNDED
- else "CURRENT ROW" if range_[0] is elements.RANGE_CURRENT
- else "%s PRECEDING" % (
- self.process(elements.literal(abs(range_[0])), **kw), )
+ else "CURRENT ROW"
+ if range_[0] is elements.RANGE_CURRENT
+ else "%s PRECEDING"
+ % (self.process(elements.literal(abs(range_[0])), **kw),)
if range_[0] < 0
- else "%s FOLLOWING" % (
- self.process(elements.literal(range_[0]), **kw), ),
-
+ else "%s FOLLOWING"
+ % (self.process(elements.literal(range_[0]), **kw),),
"UNBOUNDED FOLLOWING"
if range_[1] is elements.RANGE_UNBOUNDED
- else "CURRENT ROW" if range_[1] is elements.RANGE_CURRENT
- else "%s PRECEDING" % (
- self.process(elements.literal(abs(range_[1])), **kw), )
+ else "CURRENT ROW"
+ if range_[1] is elements.RANGE_CURRENT
+ else "%s PRECEDING"
+ % (self.process(elements.literal(abs(range_[1])), **kw),)
if range_[1] < 0
- else "%s FOLLOWING" % (
- self.process(elements.literal(range_[1]), **kw), ),
+ else "%s FOLLOWING"
+ % (self.process(elements.literal(range_[1]), **kw),),
)
def visit_over(self, over, **kwargs):
if over.range_:
range_ = "RANGE BETWEEN %s" % self._format_frame_clause(
- over.range_, **kwargs)
+ over.range_, **kwargs
+ )
elif over.rows:
range_ = "ROWS BETWEEN %s" % self._format_frame_clause(
- over.rows, **kwargs)
+ over.rows, **kwargs
+ )
else:
range_ = None
return "%s OVER (%s)" % (
over.element._compiler_dispatch(self, **kwargs),
- ' '.join([
- '%s BY %s' % (
- word, clause._compiler_dispatch(self, **kwargs)
- )
- for word, clause in (
- ('PARTITION', over.partition_by),
- ('ORDER', over.order_by)
- )
- if clause is not None and len(clause)
- ] + ([range_] if range_ else [])
- )
+ " ".join(
+ [
+ "%s BY %s"
+ % (word, clause._compiler_dispatch(self, **kwargs))
+ for word, clause in (
+ ("PARTITION", over.partition_by),
+ ("ORDER", over.order_by),
+ )
+ if clause is not None and len(clause)
+ ]
+ + ([range_] if range_ else [])
+ ),
)
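
`visit_over` and `_format_frame_clause` back the `over()` construct; a short sketch of both the PARTITION BY/ORDER BY path and a `rows=` frame (table and column names are hypothetical):

    from sqlalchemy import Column, Integer, MetaData, Table, func, select

    t = Table("t", MetaData(), Column("x", Integer), Column("y", Integer))
    rn = func.row_number().over(partition_by=t.c.x, order_by=t.c.y)
    # row_number() OVER (PARTITION BY t.x ORDER BY t.y)
    running = func.sum(t.c.y).over(order_by=t.c.x, rows=(None, 0))
    # sum(t.y) OVER (ORDER BY t.x
    #   ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)
    print(select([rn, running]))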
def visit_withingroup(self, withingroup, **kwargs):
return "%s WITHIN GROUP (ORDER BY %s)" % (
withingroup.element._compiler_dispatch(self, **kwargs),
- withingroup.order_by._compiler_dispatch(self, **kwargs)
+ withingroup.order_by._compiler_dispatch(self, **kwargs),
)
def visit_funcfilter(self, funcfilter, **kwargs):
return "%s FILTER (WHERE %s)" % (
funcfilter.func._compiler_dispatch(self, **kwargs),
- funcfilter.criterion._compiler_dispatch(self, **kwargs)
+ funcfilter.criterion._compiler_dispatch(self, **kwargs),
)
def visit_extract(self, extract, **kwargs):
field = self.extract_map.get(extract.field, extract.field)
return "EXTRACT(%s FROM %s)" % (
- field, extract.expr._compiler_dispatch(self, **kwargs))
+ field,
+ extract.expr._compiler_dispatch(self, **kwargs),
+ )
def visit_function(self, func, add_to_result_map=None, **kwargs):
if add_to_result_map is not None:
- add_to_result_map(
- func.name, func.name, (), func.type
- )
+ add_to_result_map(func.name, func.name, (), func.type)
disp = getattr(self, "visit_%s_func" % func.name.lower(), None)
if disp:
@@ -933,51 +1073,63 @@ class SQLCompiler(Compiled):
name += "%(expr)s"
else:
name = func.name + "%(expr)s"
- return ".".join(list(func.packagenames) + [name]) % \
- {'expr': self.function_argspec(func, **kwargs)}
+ return ".".join(list(func.packagenames) + [name]) % {
+ "expr": self.function_argspec(func, **kwargs)
+ }
def visit_next_value_func(self, next_value, **kw):
return self.visit_sequence(next_value.sequence)
def visit_sequence(self, sequence, **kw):
raise NotImplementedError(
- "Dialect '%s' does not support sequence increments." %
- self.dialect.name
+ "Dialect '%s' does not support sequence increments."
+ % self.dialect.name
)
def function_argspec(self, func, **kwargs):
return func.clause_expr._compiler_dispatch(self, **kwargs)
- def visit_compound_select(self, cs, asfrom=False,
- parens=True, compound_index=0, **kwargs):
+ def visit_compound_select(
+ self, cs, asfrom=False, parens=True, compound_index=0, **kwargs
+ ):
toplevel = not self.stack
entry = self._default_stack_entry if toplevel else self.stack[-1]
- need_result_map = toplevel or \
- (compound_index == 0
- and entry.get('need_result_map_for_compound', False))
+ need_result_map = toplevel or (
+ compound_index == 0
+ and entry.get("need_result_map_for_compound", False)
+ )
self.stack.append(
{
- 'correlate_froms': entry['correlate_froms'],
- 'asfrom_froms': entry['asfrom_froms'],
- 'selectable': cs,
- 'need_result_map_for_compound': need_result_map
- })
+ "correlate_froms": entry["correlate_froms"],
+ "asfrom_froms": entry["asfrom_froms"],
+ "selectable": cs,
+ "need_result_map_for_compound": need_result_map,
+ }
+ )
keyword = self.compound_keywords.get(cs.keyword)
text = (" " + keyword + " ").join(
- (c._compiler_dispatch(self,
- asfrom=asfrom, parens=False,
- compound_index=i, **kwargs)
- for i, c in enumerate(cs.selects))
+ (
+ c._compiler_dispatch(
+ self,
+ asfrom=asfrom,
+ parens=False,
+ compound_index=i,
+ **kwargs
+ )
+ for i, c in enumerate(cs.selects)
+ )
)
text += self.group_by_clause(cs, **dict(asfrom=asfrom, **kwargs))
text += self.order_by_clause(cs, **kwargs)
- text += (cs._limit_clause is not None
- or cs._offset_clause is not None) and \
- self.limit_clause(cs, **kwargs) or ""
+ text += (
+ (cs._limit_clause is not None or cs._offset_clause is not None)
+ and self.limit_clause(cs, **kwargs)
+ or ""
+ )
if self.ctes and toplevel:
text = self._render_cte_clause() + text
@@ -990,8 +1142,10 @@ class SQLCompiler(Compiled):
def _get_operator_dispatch(self, operator_, qualifier1, qualifier2):
attrname = "visit_%s_%s%s" % (
- operator_.__name__, qualifier1,
- "_" + qualifier2 if qualifier2 else "")
+ operator_.__name__,
+ qualifier1,
+ "_" + qualifier2 if qualifier2 else "",
+ )
return getattr(self, attrname, None)
def visit_unary(self, unary, **kw):
@@ -999,51 +1153,63 @@ class SQLCompiler(Compiled):
if unary.modifier:
raise exc.CompileError(
"Unary expression does not support operator "
- "and modifier simultaneously")
+ "and modifier simultaneously"
+ )
disp = self._get_operator_dispatch(
- unary.operator, "unary", "operator")
+ unary.operator, "unary", "operator"
+ )
if disp:
return disp(unary, unary.operator, **kw)
else:
return self._generate_generic_unary_operator(
- unary, OPERATORS[unary.operator], **kw)
+ unary, OPERATORS[unary.operator], **kw
+ )
elif unary.modifier:
disp = self._get_operator_dispatch(
- unary.modifier, "unary", "modifier")
+ unary.modifier, "unary", "modifier"
+ )
if disp:
return disp(unary, unary.modifier, **kw)
else:
return self._generate_generic_unary_modifier(
- unary, OPERATORS[unary.modifier], **kw)
+ unary, OPERATORS[unary.modifier], **kw
+ )
else:
raise exc.CompileError(
- "Unary expression has no operator or modifier")
+ "Unary expression has no operator or modifier"
+ )
def visit_istrue_unary_operator(self, element, operator, **kw):
- if element._is_implicitly_boolean or \
- self.dialect.supports_native_boolean:
+ if (
+ element._is_implicitly_boolean
+ or self.dialect.supports_native_boolean
+ ):
return self.process(element.element, **kw)
else:
return "%s = 1" % self.process(element.element, **kw)
def visit_isfalse_unary_operator(self, element, operator, **kw):
- if element._is_implicitly_boolean or \
- self.dialect.supports_native_boolean:
+ if (
+ element._is_implicitly_boolean
+ or self.dialect.supports_native_boolean
+ ):
return "NOT %s" % self.process(element.element, **kw)
else:
return "%s = 0" % self.process(element.element, **kw)
def visit_notmatch_op_binary(self, binary, operator, **kw):
return "NOT %s" % self.visit_binary(
- binary, override_operator=operators.match_op)
+ binary, override_operator=operators.match_op
+ )
def _emit_empty_in_warning(self):
util.warn(
- 'The IN-predicate was invoked with an '
- 'empty sequence. This results in a '
- 'contradiction, which nonetheless can be '
- 'expensive to evaluate. Consider alternative '
- 'strategies for improved performance.')
+ "The IN-predicate was invoked with an "
+ "empty sequence. This results in a "
+ "contradiction, which nonetheless can be "
+ "expensive to evaluate. Consider alternative "
+ "strategies for improved performance."
+ )
def visit_empty_in_op_binary(self, binary, operator, **kw):
if self.dialect._use_static_in:
@@ -1063,18 +1229,21 @@ class SQLCompiler(Compiled):
def visit_empty_set_expr(self, element_types):
raise NotImplementedError(
- "Dialect '%s' does not support empty set expression." %
- self.dialect.name
+ "Dialect '%s' does not support empty set expression."
+ % self.dialect.name
)
- def visit_binary(self, binary, override_operator=None,
- eager_grouping=False, **kw):
+ def visit_binary(
+ self, binary, override_operator=None, eager_grouping=False, **kw
+ ):
# don't allow "? = ?" to render
- if self.ansi_bind_rules and \
- isinstance(binary.left, elements.BindParameter) and \
- isinstance(binary.right, elements.BindParameter):
- kw['literal_binds'] = True
+ if (
+ self.ansi_bind_rules
+ and isinstance(binary.left, elements.BindParameter)
+ and isinstance(binary.right, elements.BindParameter)
+ ):
+ kw["literal_binds"] = True
operator_ = override_operator or binary.operator
disp = self._get_operator_dispatch(operator_, "binary", None)
@@ -1093,36 +1262,50 @@ class SQLCompiler(Compiled):
def visit_mod_binary(self, binary, operator, **kw):
if self.preparer._double_percents:
- return self.process(binary.left, **kw) + " %% " + \
- self.process(binary.right, **kw)
+ return (
+ self.process(binary.left, **kw)
+ + " %% "
+ + self.process(binary.right, **kw)
+ )
else:
- return self.process(binary.left, **kw) + " % " + \
- self.process(binary.right, **kw)
+ return (
+ self.process(binary.left, **kw)
+ + " % "
+ + self.process(binary.right, **kw)
+ )
def visit_custom_op_binary(self, element, operator, **kw):
- kw['eager_grouping'] = operator.eager_grouping
+ kw["eager_grouping"] = operator.eager_grouping
return self._generate_generic_binary(
- element, " " + operator.opstring + " ", **kw)
+ element, " " + operator.opstring + " ", **kw
+ )
def visit_custom_op_unary_operator(self, element, operator, **kw):
return self._generate_generic_unary_operator(
- element, operator.opstring + " ", **kw)
+ element, operator.opstring + " ", **kw
+ )
def visit_custom_op_unary_modifier(self, element, operator, **kw):
return self._generate_generic_unary_modifier(
- element, " " + operator.opstring, **kw)
+ element, " " + operator.opstring, **kw
+ )
def _generate_generic_binary(
- self, binary, opstring, eager_grouping=False, **kw):
+ self, binary, opstring, eager_grouping=False, **kw
+ ):
- _in_binary = kw.get('_in_binary', False)
+ _in_binary = kw.get("_in_binary", False)
- kw['_in_binary'] = True
- text = binary.left._compiler_dispatch(
- self, eager_grouping=eager_grouping, **kw) + \
- opstring + \
- binary.right._compiler_dispatch(
- self, eager_grouping=eager_grouping, **kw)
+ kw["_in_binary"] = True
+ text = (
+ binary.left._compiler_dispatch(
+ self, eager_grouping=eager_grouping, **kw
+ )
+ + opstring
+ + binary.right._compiler_dispatch(
+ self, eager_grouping=eager_grouping, **kw
+ )
+ )
if _in_binary and eager_grouping:
text = "(%s)" % text
@@ -1153,17 +1336,13 @@ class SQLCompiler(Compiled):
def visit_startswith_op_binary(self, binary, operator, **kw):
binary = binary._clone()
percent = self._like_percent_literal
- binary.right = percent.__radd__(
- binary.right
- )
+ binary.right = percent.__radd__(binary.right)
return self.visit_like_op_binary(binary, operator, **kw)
def visit_notstartswith_op_binary(self, binary, operator, **kw):
binary = binary._clone()
percent = self._like_percent_literal
- binary.right = percent.__radd__(
- binary.right
- )
+ binary.right = percent.__radd__(binary.right)
return self.visit_notlike_op_binary(binary, operator, **kw)
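
The startswith/endswith family clones the binary and concatenates the `%` wildcard before reusing the LIKE path; a sketch (names are illustrative):

    from sqlalchemy import Column, MetaData, String, Table, select

    t = Table("t", MetaData(), Column("name", String))
    # Renders: ... WHERE t.name LIKE :name_1 || '%'
    print(select([t]).where(t.c.name.startswith("ab")))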
def visit_endswith_op_binary(self, binary, operator, **kw):
@@ -1182,98 +1361,105 @@ class SQLCompiler(Compiled):
escape = binary.modifiers.get("escape", None)
# TODO: use ternary here, not "and"/ "or"
- return '%s LIKE %s' % (
+ return "%s LIKE %s" % (
binary.left._compiler_dispatch(self, **kw),
- binary.right._compiler_dispatch(self, **kw)) \
- + (
- ' ESCAPE ' +
- self.render_literal_value(escape, sqltypes.STRINGTYPE)
- if escape else ''
- )
+ binary.right._compiler_dispatch(self, **kw),
+ ) + (
+ " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE)
+ if escape
+ else ""
+ )
def visit_notlike_op_binary(self, binary, operator, **kw):
escape = binary.modifiers.get("escape", None)
- return '%s NOT LIKE %s' % (
+ return "%s NOT LIKE %s" % (
binary.left._compiler_dispatch(self, **kw),
- binary.right._compiler_dispatch(self, **kw)) \
- + (
- ' ESCAPE ' +
- self.render_literal_value(escape, sqltypes.STRINGTYPE)
- if escape else ''
- )
+ binary.right._compiler_dispatch(self, **kw),
+ ) + (
+ " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE)
+ if escape
+ else ""
+ )
def visit_ilike_op_binary(self, binary, operator, **kw):
escape = binary.modifiers.get("escape", None)
- return 'lower(%s) LIKE lower(%s)' % (
+ return "lower(%s) LIKE lower(%s)" % (
binary.left._compiler_dispatch(self, **kw),
- binary.right._compiler_dispatch(self, **kw)) \
- + (
- ' ESCAPE ' +
- self.render_literal_value(escape, sqltypes.STRINGTYPE)
- if escape else ''
- )
+ binary.right._compiler_dispatch(self, **kw),
+ ) + (
+ " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE)
+ if escape
+ else ""
+ )
def visit_notilike_op_binary(self, binary, operator, **kw):
escape = binary.modifiers.get("escape", None)
- return 'lower(%s) NOT LIKE lower(%s)' % (
+ return "lower(%s) NOT LIKE lower(%s)" % (
binary.left._compiler_dispatch(self, **kw),
- binary.right._compiler_dispatch(self, **kw)) \
- + (
- ' ESCAPE ' +
- self.render_literal_value(escape, sqltypes.STRINGTYPE)
- if escape else ''
- )
+ binary.right._compiler_dispatch(self, **kw),
+ ) + (
+ " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE)
+ if escape
+ else ""
+ )
def visit_between_op_binary(self, binary, operator, **kw):
symmetric = binary.modifiers.get("symmetric", False)
return self._generate_generic_binary(
- binary, " BETWEEN SYMMETRIC "
- if symmetric else " BETWEEN ", **kw)
+ binary, " BETWEEN SYMMETRIC " if symmetric else " BETWEEN ", **kw
+ )
def visit_notbetween_op_binary(self, binary, operator, **kw):
symmetric = binary.modifiers.get("symmetric", False)
return self._generate_generic_binary(
- binary, " NOT BETWEEN SYMMETRIC "
- if symmetric else " NOT BETWEEN ", **kw)
+ binary,
+ " NOT BETWEEN SYMMETRIC " if symmetric else " NOT BETWEEN ",
+ **kw
+ )
- def visit_bindparam(self, bindparam, within_columns_clause=False,
- literal_binds=False,
- skip_bind_expression=False,
- **kwargs):
+ def visit_bindparam(
+ self,
+ bindparam,
+ within_columns_clause=False,
+ literal_binds=False,
+ skip_bind_expression=False,
+ **kwargs
+ ):
if not skip_bind_expression:
impl = bindparam.type.dialect_impl(self.dialect)
if impl._has_bind_expression:
bind_expression = impl.bind_expression(bindparam)
return self.process(
- bind_expression, skip_bind_expression=True,
+ bind_expression,
+ skip_bind_expression=True,
within_columns_clause=within_columns_clause,
literal_binds=literal_binds,
**kwargs
)
- if literal_binds or \
- (within_columns_clause and
- self.ansi_bind_rules):
+ if literal_binds or (within_columns_clause and self.ansi_bind_rules):
if bindparam.value is None and bindparam.callable is None:
- raise exc.CompileError("Bind parameter '%s' without a "
- "renderable value not allowed here."
- % bindparam.key)
+ raise exc.CompileError(
+ "Bind parameter '%s' without a "
+ "renderable value not allowed here." % bindparam.key
+ )
return self.render_literal_bindparam(
- bindparam, within_columns_clause=True, **kwargs)
+ bindparam, within_columns_clause=True, **kwargs
+ )
name = self._truncate_bindparam(bindparam)
if name in self.binds:
existing = self.binds[name]
if existing is not bindparam:
- if (existing.unique or bindparam.unique) and \
- not existing.proxy_set.intersection(
- bindparam.proxy_set):
+ if (
+ existing.unique or bindparam.unique
+ ) and not existing.proxy_set.intersection(bindparam.proxy_set):
raise exc.CompileError(
"Bind parameter '%s' conflicts with "
- "unique bind parameter of the same name" %
- bindparam.key
+ "unique bind parameter of the same name"
+ % bindparam.key
)
elif existing._is_crud or bindparam._is_crud:
raise exc.CompileError(
@@ -1282,14 +1468,15 @@ class SQLCompiler(Compiled):
"clause of this "
"insert/update statement. Please use a "
"name other than column name when using bindparam() "
- "with insert() or update() (for example, 'b_%s')." %
- (bindparam.key, bindparam.key)
+ "with insert() or update() (for example, 'b_%s')."
+ % (bindparam.key, bindparam.key)
)
self.binds[bindparam.key] = self.binds[name] = bindparam
return self.bindparam_string(
- name, expanding=bindparam.expanding, **kwargs)
+ name, expanding=bindparam.expanding, **kwargs
+ )
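
The `literal_binds` branch of `visit_bindparam` is what `compile_kwargs={"literal_binds": True}` reaches; a sketch (table name is hypothetical):

    from sqlalchemy import Column, Integer, MetaData, Table, select

    t = Table("t", MetaData(), Column("x", Integer))
    stmt = select([t]).where(t.c.x == 5)
    # Inlines the value instead of a placeholder:
    # SELECT t.x FROM t WHERE t.x = 5
    print(stmt.compile(compile_kwargs={"literal_binds": True}))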
def render_literal_bindparam(self, bindparam, **kw):
value = bindparam.effective_value
@@ -1311,7 +1498,8 @@ class SQLCompiler(Compiled):
return processor(value)
else:
raise NotImplementedError(
- "Don't know how to literal-quote value %r" % value)
+ "Don't know how to literal-quote value %r" % value
+ )
def _truncate_bindparam(self, bindparam):
if bindparam in self.bind_names:
@@ -1334,8 +1522,11 @@ class SQLCompiler(Compiled):
if len(anonname) > self.label_length - 6:
counter = self.truncated_names.get(ident_class, 1)
- truncname = anonname[0:max(self.label_length - 6, 0)] + \
- "_" + hex(counter)[2:]
+ truncname = (
+ anonname[0 : max(self.label_length - 6, 0)]
+ + "_"
+ + hex(counter)[2:]
+ )
self.truncated_names[ident_class] = counter + 1
else:
truncname = anonname
@@ -1346,13 +1537,14 @@ class SQLCompiler(Compiled):
return name % self.anon_map
def _process_anon(self, key):
- (ident, derived) = key.split(' ', 1)
+ (ident, derived) = key.split(" ", 1)
anonymous_counter = self.anon_map.get(derived, 1)
self.anon_map[derived] = anonymous_counter + 1
return derived + "_" + str(anonymous_counter)
def bindparam_string(
- self, name, positional_names=None, expanding=False, **kw):
+ self, name, positional_names=None, expanding=False, **kw
+ ):
if self.positional:
if positional_names is not None:
positional_names.append(name)
@@ -1362,14 +1554,20 @@ class SQLCompiler(Compiled):
self.contains_expanding_parameters = True
return "([EXPANDING_%s])" % name
else:
- return self.bindtemplate % {'name': name}
-
- def visit_cte(self, cte, asfrom=False, ashint=False,
- fromhints=None, visiting_cte=None,
- **kwargs):
+ return self.bindtemplate % {"name": name}
+
+ def visit_cte(
+ self,
+ cte,
+ asfrom=False,
+ ashint=False,
+ fromhints=None,
+ visiting_cte=None,
+ **kwargs
+ ):
self._init_cte_state()
- kwargs['visiting_cte'] = cte
+ kwargs["visiting_cte"] = cte
if isinstance(cte.name, elements._truncated_label):
cte_name = self._truncated_identifier("alias", cte.name)
else:
@@ -1394,8 +1592,8 @@ class SQLCompiler(Compiled):
else:
raise exc.CompileError(
"Multiple, unrelated CTEs found with "
- "the same name: %r" %
- cte_name)
+ "the same name: %r" % cte_name
+ )
if asfrom or is_new_cte:
if cte._cte_alias is not None:
@@ -1403,7 +1601,8 @@ class SQLCompiler(Compiled):
cte_pre_alias_name = cte._cte_alias.name
if isinstance(cte_pre_alias_name, elements._truncated_label):
cte_pre_alias_name = self._truncated_identifier(
- "alias", cte_pre_alias_name)
+ "alias", cte_pre_alias_name
+ )
else:
pre_alias_cte = cte
cte_pre_alias_name = None
@@ -1412,11 +1611,17 @@ class SQLCompiler(Compiled):
self.ctes_by_name[cte_name] = cte
# look for embedded DML ctes and propagate autocommit
- if 'autocommit' in cte.element._execution_options and \
- 'autocommit' not in self.execution_options:
+ if (
+ "autocommit" in cte.element._execution_options
+ and "autocommit" not in self.execution_options
+ ):
self.execution_options = self.execution_options.union(
- {"autocommit":
- cte.element._execution_options['autocommit']})
+ {
+ "autocommit": cte.element._execution_options[
+ "autocommit"
+ ]
+ }
+ )
if pre_alias_cte not in self.ctes:
self.visit_cte(pre_alias_cte, **kwargs)
@@ -1432,25 +1637,30 @@ class SQLCompiler(Compiled):
col_source = cte.original.selects[0]
else:
assert False
- recur_cols = [c for c in
- util.unique_list(col_source.inner_columns)
- if c is not None]
-
- text += "(%s)" % (", ".join(
- self.preparer.format_column(ident)
- for ident in recur_cols))
+ recur_cols = [
+ c
+ for c in util.unique_list(col_source.inner_columns)
+ if c is not None
+ ]
+
+ text += "(%s)" % (
+ ", ".join(
+ self.preparer.format_column(ident)
+ for ident in recur_cols
+ )
+ )
if self.positional:
- kwargs['positional_names'] = self.cte_positional[cte] = []
+ kwargs["positional_names"] = self.cte_positional[cte] = []
- text += " AS \n" + \
- cte.original._compiler_dispatch(
- self, asfrom=True, **kwargs
- )
+ text += " AS \n" + cte.original._compiler_dispatch(
+ self, asfrom=True, **kwargs
+ )
if cte._suffixes:
text += " " + self._generate_prefixes(
- cte, cte._suffixes, **kwargs)
+ cte, cte._suffixes, **kwargs
+ )
self.ctes[cte] = text
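
The recursive branch above (the `recur_cols` column list plus the `WITH RECURSIVE` preamble) corresponds to the documented recursive-CTE pattern; a minimal sketch, with illustrative table and column names:

    from sqlalchemy import Column, Integer, MetaData, Table, select

    parts = Table(
        "parts",
        MetaData(),
        Column("part", Integer),
        Column("sub_part", Integer),
    )
    included = (
        select([parts.c.sub_part])
        .where(parts.c.part == 1)
        .cte("included", recursive=True)
    )
    included = included.union_all(
        select([parts.c.sub_part]).where(parts.c.part == included.c.sub_part)
    )
    # Renders: WITH RECURSIVE included(sub_part) AS (... UNION ALL ...)
    print(select([included.c.sub_part]))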
@@ -1467,9 +1677,15 @@ class SQLCompiler(Compiled):
else:
return self.preparer.format_alias(cte, cte_name)
- def visit_alias(self, alias, asfrom=False, ashint=False,
- iscrud=False,
- fromhints=None, **kwargs):
+ def visit_alias(
+ self,
+ alias,
+ asfrom=False,
+ ashint=False,
+ iscrud=False,
+ fromhints=None,
+ **kwargs
+ ):
if asfrom or ashint:
if isinstance(alias.name, elements._truncated_label):
alias_name = self._truncated_identifier("alias", alias.name)
@@ -1479,31 +1695,35 @@ class SQLCompiler(Compiled):
if ashint:
return self.preparer.format_alias(alias, alias_name)
elif asfrom:
- ret = alias.original._compiler_dispatch(self,
- asfrom=True, **kwargs) + \
- self.get_render_as_alias_suffix(
- self.preparer.format_alias(alias, alias_name))
+ ret = alias.original._compiler_dispatch(
+ self, asfrom=True, **kwargs
+ ) + self.get_render_as_alias_suffix(
+ self.preparer.format_alias(alias, alias_name)
+ )
if fromhints and alias in fromhints:
- ret = self.format_from_hint_text(ret, alias,
- fromhints[alias], iscrud)
+ ret = self.format_from_hint_text(
+ ret, alias, fromhints[alias], iscrud
+ )
return ret
else:
return alias.original._compiler_dispatch(self, **kwargs)
def visit_lateral(self, lateral, **kw):
- kw['lateral'] = True
+ kw["lateral"] = True
return "LATERAL %s" % self.visit_alias(lateral, **kw)
def visit_tablesample(self, tablesample, asfrom=False, **kw):
text = "%s TABLESAMPLE %s" % (
self.visit_alias(tablesample, asfrom=True, **kw),
- tablesample._get_method()._compiler_dispatch(self, **kw))
+ tablesample._get_method()._compiler_dispatch(self, **kw),
+ )
if tablesample.seed is not None:
text += " REPEATABLE (%s)" % (
- tablesample.seed._compiler_dispatch(self, **kw))
+ tablesample.seed._compiler_dispatch(self, **kw)
+ )
return text
@@ -1513,22 +1733,27 @@ class SQLCompiler(Compiled):
def _add_to_result_map(self, keyname, name, objects, type_):
self._result_columns.append((keyname, name, objects, type_))
- def _label_select_column(self, select, column,
- populate_result_map,
- asfrom, column_clause_args,
- name=None,
- within_columns_clause=True):
+ def _label_select_column(
+ self,
+ select,
+ column,
+ populate_result_map,
+ asfrom,
+ column_clause_args,
+ name=None,
+ within_columns_clause=True,
+ ):
"""produce labeled columns present in a select()."""
impl = column.type.dialect_impl(self.dialect)
- if impl._has_column_expression and \
- populate_result_map:
+ if impl._has_column_expression and populate_result_map:
col_expr = impl.column_expression(column)
def add_to_result_map(keyname, name, objects, type_):
self._add_to_result_map(
- keyname, name,
- (column,) + objects, type_)
+ keyname, name, (column,) + objects, type_
+ )
+
else:
col_expr = column
if populate_result_map:
@@ -1541,58 +1766,56 @@ class SQLCompiler(Compiled):
elif isinstance(column, elements.Label):
if col_expr is not column:
result_expr = _CompileLabel(
- col_expr,
- column.name,
- alt_names=(column.element,)
+ col_expr, column.name, alt_names=(column.element,)
)
else:
result_expr = col_expr
elif select is not None and name:
result_expr = _CompileLabel(
+ col_expr, name, alt_names=(column._key_label,)
+ )
+
+ elif (
+ asfrom
+ and isinstance(column, elements.ColumnClause)
+ and not column.is_literal
+ and column.table is not None
+ and not isinstance(column.table, selectable.Select)
+ ):
+ result_expr = _CompileLabel(
col_expr,
- name,
- alt_names=(column._key_label,)
- )
-
- elif \
- asfrom and \
- isinstance(column, elements.ColumnClause) and \
- not column.is_literal and \
- column.table is not None and \
- not isinstance(column.table, selectable.Select):
- result_expr = _CompileLabel(col_expr,
- elements._as_truncated(column.name),
- alt_names=(column.key,))
+ elements._as_truncated(column.name),
+ alt_names=(column.key,),
+ )
elif (
- not isinstance(column, elements.TextClause) and
- (
- not isinstance(column, elements.UnaryExpression) or
- column.wraps_column_expression
- ) and
- (
- not hasattr(column, 'name') or
- isinstance(column, functions.Function)
+ not isinstance(column, elements.TextClause)
+ and (
+ not isinstance(column, elements.UnaryExpression)
+ or column.wraps_column_expression
+ )
+ and (
+ not hasattr(column, "name")
+ or isinstance(column, functions.Function)
)
):
result_expr = _CompileLabel(col_expr, column.anon_label)
elif col_expr is not column:
# TODO: are we sure "column" has a .name and .key here ?
# assert isinstance(column, elements.ColumnClause)
- result_expr = _CompileLabel(col_expr,
- elements._as_truncated(column.name),
- alt_names=(column.key,))
+ result_expr = _CompileLabel(
+ col_expr,
+ elements._as_truncated(column.name),
+ alt_names=(column.key,),
+ )
else:
result_expr = col_expr
column_clause_args.update(
within_columns_clause=within_columns_clause,
- add_to_result_map=add_to_result_map
- )
- return result_expr._compiler_dispatch(
- self,
- **column_clause_args
+ add_to_result_map=add_to_result_map,
)
+ return result_expr._compiler_dispatch(self, **column_clause_args)
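
`_label_select_column` is the per-column labeling step behind `use_labels`; a sketch of the observable effect (names are illustrative):

    from sqlalchemy import Column, Integer, MetaData, Table, select

    t = Table("t", MetaData(), Column("x", Integer))
    # Each column is routed through _label_select_column with a
    # truncated <table>_<column> label:
    # SELECT t.x AS t_x FROM t
    print(select([t]).apply_labels())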
def format_from_hint_text(self, sqltext, table, hint, iscrud):
hinttext = self.get_from_hint_text(table, hint)
@@ -1631,8 +1854,11 @@ class SQLCompiler(Compiled):
newelem = cloned[element] = element._clone()
- if newelem.is_selectable and newelem._is_join and \
- isinstance(newelem.right, selectable.FromGrouping):
+ if (
+ newelem.is_selectable
+ and newelem._is_join
+ and isinstance(newelem.right, selectable.FromGrouping)
+ ):
newelem._reset_exported()
newelem.left = visit(newelem.left, **kw)
@@ -1640,8 +1866,8 @@ class SQLCompiler(Compiled):
right = visit(newelem.right, **kw)
selectable_ = selectable.Select(
- [right.element],
- use_labels=True).alias()
+ [right.element], use_labels=True
+ ).alias()
for c in selectable_.c:
c._key_label = c.key
@@ -1680,17 +1906,18 @@ class SQLCompiler(Compiled):
elif newelem._is_from_container:
# if we hit an Alias, CompoundSelect or ScalarSelect, put a
# marker in the stack.
- kw['transform_clue'] = 'select_container'
+ kw["transform_clue"] = "select_container"
newelem._copy_internals(clone=visit, **kw)
elif newelem.is_selectable and newelem._is_select:
- barrier_select = kw.get('transform_clue', None) == \
- 'select_container'
+ barrier_select = (
+ kw.get("transform_clue", None) == "select_container"
+ )
# if we're still descended from an
# Alias/CompoundSelect/ScalarSelect, we're
# in a FROM clause, so start with a new translate collection
if barrier_select:
column_translate.append({})
- kw['transform_clue'] = 'inside_select'
+ kw["transform_clue"] = "inside_select"
newelem._copy_internals(clone=visit, **kw)
if barrier_select:
del column_translate[-1]
@@ -1702,24 +1929,22 @@ class SQLCompiler(Compiled):
return visit(select)
def _transform_result_map_for_nested_joins(
- self, select, transformed_select):
- inner_col = dict((c._key_label, c) for
- c in transformed_select.inner_columns)
-
- d = dict(
- (inner_col[c._key_label], c)
- for c in select.inner_columns
+ self, select, transformed_select
+ ):
+ inner_col = dict(
+ (c._key_label, c) for c in transformed_select.inner_columns
)
+ d = dict((inner_col[c._key_label], c) for c in select.inner_columns)
+
self._result_columns = [
(key, name, tuple([d.get(col, col) for col in objs]), typ)
for key, name, objs, typ in self._result_columns
]
- _default_stack_entry = util.immutabledict([
- ('correlate_froms', frozenset()),
- ('asfrom_froms', frozenset())
- ])
+ _default_stack_entry = util.immutabledict(
+ [("correlate_froms", frozenset()), ("asfrom_froms", frozenset())]
+ )
def _display_froms_for_select(self, select, asfrom, lateral=False):
# utility method to help external dialects
@@ -1729,72 +1954,88 @@ class SQLCompiler(Compiled):
toplevel = not self.stack
entry = self._default_stack_entry if toplevel else self.stack[-1]
- correlate_froms = entry['correlate_froms']
- asfrom_froms = entry['asfrom_froms']
+ correlate_froms = entry["correlate_froms"]
+ asfrom_froms = entry["asfrom_froms"]
if asfrom and not lateral:
froms = select._get_display_froms(
explicit_correlate_froms=correlate_froms.difference(
- asfrom_froms),
- implicit_correlate_froms=())
+ asfrom_froms
+ ),
+ implicit_correlate_froms=(),
+ )
else:
froms = select._get_display_froms(
explicit_correlate_froms=correlate_froms,
- implicit_correlate_froms=asfrom_froms)
+ implicit_correlate_froms=asfrom_froms,
+ )
return froms
- def visit_select(self, select, asfrom=False, parens=True,
- fromhints=None,
- compound_index=0,
- nested_join_translation=False,
- select_wraps_for=None,
- lateral=False,
- **kwargs):
-
- needs_nested_translation = \
- select.use_labels and \
- not nested_join_translation and \
- not self.stack and \
- not self.dialect.supports_right_nested_joins
+ def visit_select(
+ self,
+ select,
+ asfrom=False,
+ parens=True,
+ fromhints=None,
+ compound_index=0,
+ nested_join_translation=False,
+ select_wraps_for=None,
+ lateral=False,
+ **kwargs
+ ):
+
+ needs_nested_translation = (
+ select.use_labels
+ and not nested_join_translation
+ and not self.stack
+ and not self.dialect.supports_right_nested_joins
+ )
if needs_nested_translation:
transformed_select = self._transform_select_for_nested_joins(
- select)
+ select
+ )
text = self.visit_select(
- transformed_select, asfrom=asfrom, parens=parens,
+ transformed_select,
+ asfrom=asfrom,
+ parens=parens,
fromhints=fromhints,
compound_index=compound_index,
- nested_join_translation=True, **kwargs
+ nested_join_translation=True,
+ **kwargs
)
toplevel = not self.stack
entry = self._default_stack_entry if toplevel else self.stack[-1]
- populate_result_map = toplevel or \
- (
- compound_index == 0 and entry.get(
- 'need_result_map_for_compound', False)
- ) or entry.get('need_result_map_for_nested', False)
+ populate_result_map = (
+ toplevel
+ or (
+ compound_index == 0
+ and entry.get("need_result_map_for_compound", False)
+ )
+ or entry.get("need_result_map_for_nested", False)
+ )
# this was first proposed as part of #3372; however, it is not
# reached in current tests and could possibly be an assertion
# instead.
- if not populate_result_map and 'add_to_result_map' in kwargs:
- del kwargs['add_to_result_map']
+ if not populate_result_map and "add_to_result_map" in kwargs:
+ del kwargs["add_to_result_map"]
if needs_nested_translation:
if populate_result_map:
self._transform_result_map_for_nested_joins(
- select, transformed_select)
+ select, transformed_select
+ )
return text
froms = self._setup_select_stack(select, entry, asfrom, lateral)
column_clause_args = kwargs.copy()
- column_clause_args.update({
- 'within_label_clause': False,
- 'within_columns_clause': False
- })
+ column_clause_args.update(
+ {"within_label_clause": False, "within_columns_clause": False}
+ )
text = "SELECT " # we're off to a good start !
@@ -1806,19 +2047,21 @@ class SQLCompiler(Compiled):
byfrom = None
if select._prefixes:
- text += self._generate_prefixes(
- select, select._prefixes, **kwargs)
+ text += self._generate_prefixes(select, select._prefixes, **kwargs)
text += self.get_select_precolumns(select, **kwargs)
# the actual list of columns to print in the SELECT column list.
inner_columns = [
- c for c in [
+ c
+ for c in [
self._label_select_column(
select,
column,
- populate_result_map, asfrom,
+ populate_result_map,
+ asfrom,
column_clause_args,
- name=name)
+ name=name,
+ )
for name, column in select._columns_plus_names
]
if c is not None
@@ -1831,8 +2074,11 @@ class SQLCompiler(Compiled):
translate = dict(
zip(
[name for (key, name) in select._columns_plus_names],
- [name for (key, name) in
- select_wraps_for._columns_plus_names])
+ [
+ name
+ for (key, name) in select_wraps_for._columns_plus_names
+ ],
+ )
)
self._result_columns = [
@@ -1841,13 +2087,14 @@ class SQLCompiler(Compiled):
]
text = self._compose_select_body(
- text, select, inner_columns, froms, byfrom, kwargs)
+ text, select, inner_columns, froms, byfrom, kwargs
+ )
if select._statement_hints:
per_dialect = [
- ht for (dialect_name, ht)
- in select._statement_hints
- if dialect_name in ('*', self.dialect.name)
+ ht
+ for (dialect_name, ht) in select._statement_hints
+ if dialect_name in ("*", self.dialect.name)
]
if per_dialect:
text += " " + self.get_statement_hint_text(per_dialect)
@@ -1857,7 +2104,8 @@ class SQLCompiler(Compiled):
if select._suffixes:
text += " " + self._generate_prefixes(
- select, select._suffixes, **kwargs)
+ select, select._suffixes, **kwargs
+ )
self.stack.pop(-1)
@@ -1867,60 +2115,73 @@ class SQLCompiler(Compiled):
return text
def _setup_select_hints(self, select):
- byfrom = dict([
- (from_, hinttext % {
- 'name': from_._compiler_dispatch(
- self, ashint=True)
- })
- for (from_, dialect), hinttext in
- select._hints.items()
- if dialect in ('*', self.dialect.name)
- ])
+ byfrom = dict(
+ [
+ (
+ from_,
+ hinttext
+ % {"name": from_._compiler_dispatch(self, ashint=True)},
+ )
+ for (from_, dialect), hinttext in select._hints.items()
+ if dialect in ("*", self.dialect.name)
+ ]
+ )
hint_text = self.get_select_hint_text(byfrom)
return hint_text, byfrom
def _setup_select_stack(self, select, entry, asfrom, lateral):
- correlate_froms = entry['correlate_froms']
- asfrom_froms = entry['asfrom_froms']
+ correlate_froms = entry["correlate_froms"]
+ asfrom_froms = entry["asfrom_froms"]
if asfrom and not lateral:
froms = select._get_display_froms(
explicit_correlate_froms=correlate_froms.difference(
- asfrom_froms),
- implicit_correlate_froms=())
+ asfrom_froms
+ ),
+ implicit_correlate_froms=(),
+ )
else:
froms = select._get_display_froms(
explicit_correlate_froms=correlate_froms,
- implicit_correlate_froms=asfrom_froms)
+ implicit_correlate_froms=asfrom_froms,
+ )
new_correlate_froms = set(selectable._from_objects(*froms))
all_correlate_froms = new_correlate_froms.union(correlate_froms)
new_entry = {
- 'asfrom_froms': new_correlate_froms,
- 'correlate_froms': all_correlate_froms,
- 'selectable': select,
+ "asfrom_froms": new_correlate_froms,
+ "correlate_froms": all_correlate_froms,
+ "selectable": select,
}
self.stack.append(new_entry)
return froms
def _compose_select_body(
- self, text, select, inner_columns, froms, byfrom, kwargs):
- text += ', '.join(inner_columns)
+ self, text, select, inner_columns, froms, byfrom, kwargs
+ ):
+ text += ", ".join(inner_columns)
if froms:
text += " \nFROM "
if select._hints:
- text += ', '.join(
- [f._compiler_dispatch(self, asfrom=True,
- fromhints=byfrom, **kwargs)
- for f in froms])
+ text += ", ".join(
+ [
+ f._compiler_dispatch(
+ self, asfrom=True, fromhints=byfrom, **kwargs
+ )
+ for f in froms
+ ]
+ )
else:
- text += ', '.join(
- [f._compiler_dispatch(self, asfrom=True, **kwargs)
- for f in froms])
+ text += ", ".join(
+ [
+ f._compiler_dispatch(self, asfrom=True, **kwargs)
+ for f in froms
+ ]
+ )
else:
text += self.default_from()
@@ -1940,8 +2201,10 @@ class SQLCompiler(Compiled):
if select._order_by_clause.clauses:
text += self.order_by_clause(select, **kwargs)
- if (select._limit_clause is not None or
- select._offset_clause is not None):
+ if (
+ select._limit_clause is not None
+ or select._offset_clause is not None
+ ):
text += self.limit_clause(select, **kwargs)
if select._for_update_arg is not None:
@@ -1953,8 +2216,7 @@ class SQLCompiler(Compiled):
clause = " ".join(
prefix._compiler_dispatch(self, **kw)
for prefix, dialect_name in prefixes
- if dialect_name is None or
- dialect_name == self.dialect.name
+ if dialect_name is None or dialect_name == self.dialect.name
)
if clause:
clause += " "
@@ -1962,14 +2224,12 @@ class SQLCompiler(Compiled):
def _render_cte_clause(self):
if self.positional:
- self.positiontup = sum([
- self.cte_positional[cte]
- for cte in self.ctes], []) + \
- self.positiontup
+ self.positiontup = (
+ sum([self.cte_positional[cte] for cte in self.ctes], [])
+ + self.positiontup
+ )
cte_text = self.get_cte_preamble(self.ctes_recursive) + " "
- cte_text += ", \n".join(
- [txt for txt in self.ctes.values()]
- )
+ cte_text += ", \n".join([txt for txt in self.ctes.values()])
cte_text += "\n "
return cte_text
@@ -2010,7 +2270,8 @@ class SQLCompiler(Compiled):
def returning_clause(self, stmt, returning_cols):
raise exc.CompileError(
"RETURNING is not supported by this "
- "dialect's statement compiler.")
+ "dialect's statement compiler."
+ )
def limit_clause(self, select, **kw):
text = ""
@@ -2022,19 +2283,31 @@ class SQLCompiler(Compiled):
text += " OFFSET " + self.process(select._offset_clause, **kw)
return text
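
The default `limit_clause` emits bound parameters rather than literal integers; a sketch (table name is hypothetical):

    from sqlalchemy import Column, Integer, MetaData, Table, select

    t = Table("t", MetaData(), Column("x", Integer))
    # SELECT t.x FROM t LIMIT :param_1 OFFSET :param_2
    print(select([t]).limit(10).offset(5))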
- def visit_table(self, table, asfrom=False, iscrud=False, ashint=False,
- fromhints=None, use_schema=True, **kwargs):
+ def visit_table(
+ self,
+ table,
+ asfrom=False,
+ iscrud=False,
+ ashint=False,
+ fromhints=None,
+ use_schema=True,
+ **kwargs
+ ):
if asfrom or ashint:
effective_schema = self.preparer.schema_for_object(table)
if use_schema and effective_schema:
- ret = self.preparer.quote_schema(effective_schema) + \
- "." + self.preparer.quote(table.name)
+ ret = (
+ self.preparer.quote_schema(effective_schema)
+ + "."
+ + self.preparer.quote(table.name)
+ )
else:
ret = self.preparer.quote(table.name)
if fromhints and table in fromhints:
- ret = self.format_from_hint_text(ret, table,
- fromhints[table], iscrud)
+ ret = self.format_from_hint_text(
+ ret, table, fromhints[table], iscrud
+ )
return ret
else:
return ""
@@ -2047,26 +2320,24 @@ class SQLCompiler(Compiled):
else:
join_type = " JOIN "
return (
- join.left._compiler_dispatch(self, asfrom=True, **kwargs) +
- join_type +
- join.right._compiler_dispatch(self, asfrom=True, **kwargs) +
- " ON " +
- join.onclause._compiler_dispatch(self, **kwargs)
+ join.left._compiler_dispatch(self, asfrom=True, **kwargs)
+ + join_type
+ + join.right._compiler_dispatch(self, asfrom=True, **kwargs)
+ + " ON "
+ + join.onclause._compiler_dispatch(self, **kwargs)
)
def _setup_crud_hints(self, stmt, table_text):
- dialect_hints = dict([
- (table, hint_text)
- for (table, dialect), hint_text in
- stmt._hints.items()
- if dialect in ('*', self.dialect.name)
- ])
+ dialect_hints = dict(
+ [
+ (table, hint_text)
+ for (table, dialect), hint_text in stmt._hints.items()
+ if dialect in ("*", self.dialect.name)
+ ]
+ )
if stmt.table in dialect_hints:
table_text = self.format_from_hint_text(
- table_text,
- stmt.table,
- dialect_hints[stmt.table],
- True
+ table_text, stmt.table, dialect_hints[stmt.table], True
)
return dialect_hints, table_text
@@ -2074,28 +2345,35 @@ class SQLCompiler(Compiled):
toplevel = not self.stack
self.stack.append(
- {'correlate_froms': set(),
- "asfrom_froms": set(),
- "selectable": insert_stmt})
+ {
+ "correlate_froms": set(),
+ "asfrom_froms": set(),
+ "selectable": insert_stmt,
+ }
+ )
crud_params = crud._setup_crud_params(
- self, insert_stmt, crud.ISINSERT, **kw)
+ self, insert_stmt, crud.ISINSERT, **kw
+ )
- if not crud_params and \
- not self.dialect.supports_default_values and \
- not self.dialect.supports_empty_insert:
- raise exc.CompileError("The '%s' dialect with current database "
- "version settings does not support empty "
- "inserts." %
- self.dialect.name)
+ if (
+ not crud_params
+ and not self.dialect.supports_default_values
+ and not self.dialect.supports_empty_insert
+ ):
+ raise exc.CompileError(
+ "The '%s' dialect with current database "
+ "version settings does not support empty "
+ "inserts." % self.dialect.name
+ )
if insert_stmt._has_multi_parameters:
if not self.dialect.supports_multivalues_insert:
raise exc.CompileError(
"The '%s' dialect with current database "
"version settings does not support "
- "in-place multirow inserts." %
- self.dialect.name)
+ "in-place multirow inserts." % self.dialect.name
+ )
crud_params_single = crud_params[0]
else:
crud_params_single = crud_params
@@ -2106,27 +2384,31 @@ class SQLCompiler(Compiled):
text = "INSERT "
if insert_stmt._prefixes:
- text += self._generate_prefixes(insert_stmt,
- insert_stmt._prefixes, **kw)
+ text += self._generate_prefixes(
+ insert_stmt, insert_stmt._prefixes, **kw
+ )
text += "INTO "
table_text = preparer.format_table(insert_stmt.table)
if insert_stmt._hints:
dialect_hints, table_text = self._setup_crud_hints(
- insert_stmt, table_text)
+ insert_stmt, table_text
+ )
else:
dialect_hints = None
text += table_text
if crud_params_single or not supports_default_values:
- text += " (%s)" % ', '.join([preparer.format_column(c[0])
- for c in crud_params_single])
+ text += " (%s)" % ", ".join(
+ [preparer.format_column(c[0]) for c in crud_params_single]
+ )
if self.returning or insert_stmt._returning:
returning_clause = self.returning_clause(
- insert_stmt, self.returning or insert_stmt._returning)
+ insert_stmt, self.returning or insert_stmt._returning
+ )
if self.returning_precedes_values:
text += " " + returning_clause
@@ -2145,19 +2427,17 @@ class SQLCompiler(Compiled):
elif insert_stmt._has_multi_parameters:
text += " VALUES %s" % (
", ".join(
- "(%s)" % (
- ', '.join(c[1] for c in crud_param_set)
- )
+ "(%s)" % (", ".join(c[1] for c in crud_param_set))
for crud_param_set in crud_params
)
)
else:
- text += " VALUES (%s)" % \
- ', '.join([c[1] for c in crud_params])
+ text += " VALUES (%s)" % ", ".join([c[1] for c in crud_params])
if insert_stmt._post_values_clause is not None:
post_values_clause = self.process(
- insert_stmt._post_values_clause, **kw)
+ insert_stmt._post_values_clause, **kw
+ )
if post_values_clause:
text += " " + post_values_clause
@@ -2178,21 +2458,19 @@ class SQLCompiler(Compiled):
"""Provide a hook for MySQL to add LIMIT to the UPDATE"""
return None
- def update_tables_clause(self, update_stmt, from_table,
- extra_froms, **kw):
+ def update_tables_clause(self, update_stmt, from_table, extra_froms, **kw):
"""Provide a hook to override the initial table clause
in an UPDATE statement.
MySQL overrides this.
"""
- kw['asfrom'] = True
+ kw["asfrom"] = True
return from_table._compiler_dispatch(self, iscrud=True, **kw)
- def update_from_clause(self, update_stmt,
- from_table, extra_froms,
- from_hints,
- **kw):
+ def update_from_clause(
+ self, update_stmt, from_table, extra_froms, from_hints, **kw
+ ):
"""Provide a hook to override the generation of an
UPDATE..FROM clause.
@@ -2201,7 +2479,8 @@ class SQLCompiler(Compiled):
"""
raise NotImplementedError(
"This backend does not support multiple-table "
- "criteria within UPDATE")
+ "criteria within UPDATE"
+ )
def visit_update(self, update_stmt, asfrom=False, **kw):
toplevel = not self.stack
@@ -2221,49 +2500,61 @@ class SQLCompiler(Compiled):
correlate_froms = {update_stmt.table}
self.stack.append(
- {'correlate_froms': correlate_froms,
- "asfrom_froms": correlate_froms,
- "selectable": update_stmt})
+ {
+ "correlate_froms": correlate_froms,
+ "asfrom_froms": correlate_froms,
+ "selectable": update_stmt,
+ }
+ )
text = "UPDATE "
if update_stmt._prefixes:
- text += self._generate_prefixes(update_stmt,
- update_stmt._prefixes, **kw)
+ text += self._generate_prefixes(
+ update_stmt, update_stmt._prefixes, **kw
+ )
- table_text = self.update_tables_clause(update_stmt, update_stmt.table,
- render_extra_froms, **kw)
+ table_text = self.update_tables_clause(
+ update_stmt, update_stmt.table, render_extra_froms, **kw
+ )
crud_params = crud._setup_crud_params(
- self, update_stmt, crud.ISUPDATE, **kw)
+ self, update_stmt, crud.ISUPDATE, **kw
+ )
if update_stmt._hints:
dialect_hints, table_text = self._setup_crud_hints(
- update_stmt, table_text)
+ update_stmt, table_text
+ )
else:
dialect_hints = None
text += table_text
- text += ' SET '
- include_table = is_multitable and \
- self.render_table_with_column_in_update_from
- text += ', '.join(
- c[0]._compiler_dispatch(self,
- include_table=include_table) +
- '=' + c[1] for c in crud_params
+ text += " SET "
+ include_table = (
+ is_multitable and self.render_table_with_column_in_update_from
+ )
+ text += ", ".join(
+ c[0]._compiler_dispatch(self, include_table=include_table)
+ + "="
+ + c[1]
+ for c in crud_params
)
if self.returning or update_stmt._returning:
if self.returning_precedes_values:
text += " " + self.returning_clause(
- update_stmt, self.returning or update_stmt._returning)
+ update_stmt, self.returning or update_stmt._returning
+ )
if extra_froms:
extra_from_text = self.update_from_clause(
update_stmt,
update_stmt.table,
render_extra_froms,
- dialect_hints, **kw)
+ dialect_hints,
+ **kw
+ )
if extra_from_text:
text += " " + extra_from_text
@@ -2276,10 +2567,12 @@ class SQLCompiler(Compiled):
if limit_clause:
text += " " + limit_clause
- if (self.returning or update_stmt._returning) and \
- not self.returning_precedes_values:
+ if (
+ self.returning or update_stmt._returning
+ ) and not self.returning_precedes_values:
text += " " + self.returning_clause(
- update_stmt, self.returning or update_stmt._returning)
+ update_stmt, self.returning or update_stmt._returning
+ )
if self.ctes and toplevel:
text = self._render_cte_clause() + text
@@ -2295,9 +2588,9 @@ class SQLCompiler(Compiled):
def _key_getters_for_crud_column(self):
return crud._key_getters_for_crud_column(self, self.statement)
- def delete_extra_from_clause(self, update_stmt,
- from_table, extra_froms,
- from_hints, **kw):
+ def delete_extra_from_clause(
+ self, update_stmt, from_table, extra_froms, from_hints, **kw
+ ):
"""Provide a hook to override the generation of an
DELETE..FROM clause.
@@ -2308,10 +2601,10 @@ class SQLCompiler(Compiled):
"""
raise NotImplementedError(
"This backend does not support multiple-table "
- "criteria within DELETE")
+ "criteria within DELETE"
+ )
- def delete_table_clause(self, delete_stmt, from_table,
- extra_froms):
+ def delete_table_clause(self, delete_stmt, from_table, extra_froms):
return from_table._compiler_dispatch(self, asfrom=True, iscrud=True)
def visit_delete(self, delete_stmt, asfrom=False, **kw):
@@ -2322,23 +2615,30 @@ class SQLCompiler(Compiled):
extra_froms = delete_stmt._extra_froms
correlate_froms = {delete_stmt.table}.union(extra_froms)
- self.stack.append({'correlate_froms': correlate_froms,
- "asfrom_froms": correlate_froms,
- "selectable": delete_stmt})
+ self.stack.append(
+ {
+ "correlate_froms": correlate_froms,
+ "asfrom_froms": correlate_froms,
+ "selectable": delete_stmt,
+ }
+ )
text = "DELETE "
if delete_stmt._prefixes:
- text += self._generate_prefixes(delete_stmt,
- delete_stmt._prefixes, **kw)
+ text += self._generate_prefixes(
+ delete_stmt, delete_stmt._prefixes, **kw
+ )
text += "FROM "
- table_text = self.delete_table_clause(delete_stmt, delete_stmt.table,
- extra_froms)
+ table_text = self.delete_table_clause(
+ delete_stmt, delete_stmt.table, extra_froms
+ )
if delete_stmt._hints:
dialect_hints, table_text = self._setup_crud_hints(
- delete_stmt, table_text)
+ delete_stmt, table_text
+ )
else:
dialect_hints = None
@@ -2347,14 +2647,17 @@ class SQLCompiler(Compiled):
if delete_stmt._returning:
if self.returning_precedes_values:
text += " " + self.returning_clause(
- delete_stmt, delete_stmt._returning)
+ delete_stmt, delete_stmt._returning
+ )
if extra_froms:
extra_from_text = self.delete_extra_from_clause(
delete_stmt,
delete_stmt.table,
extra_froms,
- dialect_hints, **kw)
+ dialect_hints,
+ **kw
+ )
if extra_from_text:
text += " " + extra_from_text
@@ -2365,7 +2668,8 @@ class SQLCompiler(Compiled):
if delete_stmt._returning and not self.returning_precedes_values:
text += " " + self.returning_clause(
- delete_stmt, delete_stmt._returning)
+ delete_stmt, delete_stmt._returning
+ )
if self.ctes and toplevel:
text = self._render_cte_clause() + text
@@ -2381,12 +2685,14 @@ class SQLCompiler(Compiled):
return "SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt)
def visit_rollback_to_savepoint(self, savepoint_stmt):
- return "ROLLBACK TO SAVEPOINT %s" % \
- self.preparer.format_savepoint(savepoint_stmt)
+ return "ROLLBACK TO SAVEPOINT %s" % self.preparer.format_savepoint(
+ savepoint_stmt
+ )
def visit_release_savepoint(self, savepoint_stmt):
- return "RELEASE SAVEPOINT %s" % \
- self.preparer.format_savepoint(savepoint_stmt)
+ return "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint(
+ savepoint_stmt
+ )
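# For reference, these visitors produce plain strings such as
# "SAVEPOINT sp_1", "ROLLBACK TO SAVEPOINT sp_1" and
# "RELEASE SAVEPOINT sp_1"; quoting of the savepoint name itself is
# delegated to the preparer.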
class StrSQLCompiler(SQLCompiler):
@@ -2403,7 +2709,7 @@ class StrSQLCompiler(SQLCompiler):
def visit_getitem_binary(self, binary, operator, **kw):
return "%s[%s]" % (
self.process(binary.left, **kw),
- self.process(binary.right, **kw)
+ self.process(binary.right, **kw),
)
def visit_json_getitem_op_binary(self, binary, operator, **kw):
@@ -2421,29 +2727,26 @@ class StrSQLCompiler(SQLCompiler):
for c in elements._select_iterables(returning_cols)
]
- return 'RETURNING ' + ', '.join(columns)
+ return "RETURNING " + ", ".join(columns)
- def update_from_clause(self, update_stmt,
- from_table, extra_froms,
- from_hints,
- **kw):
- return "FROM " + ', '.join(
- t._compiler_dispatch(self, asfrom=True,
- fromhints=from_hints, **kw)
- for t in extra_froms)
+ def update_from_clause(
+ self, update_stmt, from_table, extra_froms, from_hints, **kw
+ ):
+ return "FROM " + ", ".join(
+ t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw)
+ for t in extra_froms
+ )
- def delete_extra_from_clause(self, update_stmt,
- from_table, extra_froms,
- from_hints,
- **kw):
- return ', ' + ', '.join(
- t._compiler_dispatch(self, asfrom=True,
- fromhints=from_hints, **kw)
- for t in extra_froms)
+ def delete_extra_from_clause(
+ self, update_stmt, from_table, extra_froms, from_hints, **kw
+ ):
+ return ", " + ", ".join(
+ t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw)
+ for t in extra_froms
+ )
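# Usage sketch: str() compiles against these string-only fallbacks, so a
# multi-table DELETE renders instead of raising (output is approximate):
from sqlalchemy import Column, Integer, MetaData, Table, delete

m = MetaData()
a = Table("a", m, Column("id", Integer))
b = Table("b", m, Column("a_id", Integer))
print(delete(a).where(a.c.id == b.c.a_id))
# DELETE FROM a , b WHERE a.id = b.a_id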
class DDLCompiler(Compiled):
-
@util.memoized_property
def sql_compiler(self):
return self.dialect.statement_compiler(self.dialect, None)
@@ -2464,13 +2767,13 @@ class DDLCompiler(Compiled):
preparer = self.preparer
path = preparer.format_table_seq(ddl.target)
if len(path) == 1:
- table, sch = path[0], ''
+ table, sch = path[0], ""
else:
table, sch = path[-1], path[0]
- context.setdefault('table', table)
- context.setdefault('schema', sch)
- context.setdefault('fullname', preparer.format_table(ddl.target))
+ context.setdefault("table", table)
+ context.setdefault("schema", sch)
+ context.setdefault("fullname", preparer.format_table(ddl.target))
return self.sql_compiler.post_process_text(ddl.statement % context)
@@ -2507,9 +2810,9 @@ class DDLCompiler(Compiled):
for create_column in create.columns:
column = create_column.element
try:
- processed = self.process(create_column,
- first_pk=column.primary_key
- and not first_pk)
+ processed = self.process(
+ create_column, first_pk=column.primary_key and not first_pk
+ )
if processed is not None:
text += separator
separator = ", \n"
@@ -2519,13 +2822,15 @@ class DDLCompiler(Compiled):
except exc.CompileError as ce:
util.raise_from_cause(
exc.CompileError(
- util.u("(in table '%s', column '%s'): %s") %
- (table.description, column.name, ce.args[0])
- ))
+ util.u("(in table '%s', column '%s'): %s")
+ % (table.description, column.name, ce.args[0])
+ )
+ )
const = self.create_table_constraints(
- table, _include_foreign_key_constraints= # noqa
- create.include_foreign_key_constraints)
+ table,
+ _include_foreign_key_constraints=create.include_foreign_key_constraints, # noqa
+ )
if const:
text += separator + "\t" + const
@@ -2538,20 +2843,18 @@ class DDLCompiler(Compiled):
if column.system:
return None
- text = self.get_column_specification(
- column,
- first_pk=first_pk
+ text = self.get_column_specification(column, first_pk=first_pk)
+ const = " ".join(
+ self.process(constraint) for constraint in column.constraints
)
- const = " ".join(self.process(constraint)
- for constraint in column.constraints)
if const:
text += " " + const
return text
def create_table_constraints(
- self, table,
- _include_foreign_key_constraints=None):
+ self, table, _include_foreign_key_constraints=None
+ ):
# On some DBs the order is significant: visit PK first, then the
# other constraints (engine.ReflectionTest.testbasic failed on FB2)
@@ -2565,21 +2868,29 @@ class DDLCompiler(Compiled):
else:
omit_fkcs = set()
- constraints.extend([c for c in table._sorted_constraints
- if c is not table.primary_key and
- c not in omit_fkcs])
+ constraints.extend(
+ [
+ c
+ for c in table._sorted_constraints
+ if c is not table.primary_key and c not in omit_fkcs
+ ]
+ )
return ", \n\t".join(
- p for p in
- (self.process(constraint)
+ p
+ for p in (
+ self.process(constraint)
for constraint in constraints
if (
- constraint._create_rule is None or
- constraint._create_rule(self))
+ constraint._create_rule is None
+ or constraint._create_rule(self)
+ )
and (
- not self.dialect.supports_alter or
- not getattr(constraint, 'use_alter', False)
- )) if p is not None
+ not self.dialect.supports_alter
+ or not getattr(constraint, "use_alter", False)
+ )
+ )
+ if p is not None
)
def visit_drop_table(self, drop):
@@ -2590,34 +2901,38 @@ class DDLCompiler(Compiled):
def _verify_index_table(self, index):
if index.table is None:
- raise exc.CompileError("Index '%s' is not associated "
- "with any table." % index.name)
+ raise exc.CompileError(
+ "Index '%s' is not associated " "with any table." % index.name
+ )
- def visit_create_index(self, create, include_schema=False,
- include_table_schema=True):
+ def visit_create_index(
+ self, create, include_schema=False, include_table_schema=True
+ ):
index = create.element
self._verify_index_table(index)
preparer = self.preparer
text = "CREATE "
if index.unique:
text += "UNIQUE "
- text += "INDEX %s ON %s (%s)" \
- % (
- self._prepared_index_name(index,
- include_schema=include_schema),
- preparer.format_table(index.table,
- use_schema=include_table_schema),
- ', '.join(
- self.sql_compiler.process(
- expr, include_table=False, literal_binds=True) for
- expr in index.expressions)
- )
+ text += "INDEX %s ON %s (%s)" % (
+ self._prepared_index_name(index, include_schema=include_schema),
+ preparer.format_table(
+ index.table, use_schema=include_table_schema
+ ),
+ ", ".join(
+ self.sql_compiler.process(
+ expr, include_table=False, literal_binds=True
+ )
+ for expr in index.expressions
+ ),
+ )
return text
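# Example of the DDL emitted by the visitor above (a sketch; exact
# quoting depends on the dialect's preparer; "user" is quoted here
# because it is a reserved word):
from sqlalchemy import Column, Index, Integer, MetaData, Table
from sqlalchemy.schema import CreateIndex

m = MetaData()
t = Table("user", m, Column("name", Integer))
ix = Index("ix_user_name", t.c.name, unique=True)
print(CreateIndex(ix))
# CREATE UNIQUE INDEX ix_user_name ON "user" (name)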
def visit_drop_index(self, drop):
index = drop.element
return "\nDROP INDEX " + self._prepared_index_name(
- index, include_schema=True)
+ index, include_schema=True
+ )
def _prepared_index_name(self, index, include_schema=False):
if index.table is not None:
@@ -2638,35 +2953,41 @@ class DDLCompiler(Compiled):
def visit_add_constraint(self, create):
return "ALTER TABLE %s ADD %s" % (
self.preparer.format_table(create.element.table),
- self.process(create.element)
+ self.process(create.element),
)
def visit_set_table_comment(self, create):
return "COMMENT ON TABLE %s IS %s" % (
self.preparer.format_table(create.element),
self.sql_compiler.render_literal_value(
- create.element.comment, sqltypes.String())
+ create.element.comment, sqltypes.String()
+ ),
)
def visit_drop_table_comment(self, drop):
- return "COMMENT ON TABLE %s IS NULL" % \
- self.preparer.format_table(drop.element)
+ return "COMMENT ON TABLE %s IS NULL" % self.preparer.format_table(
+ drop.element
+ )
def visit_set_column_comment(self, create):
return "COMMENT ON COLUMN %s IS %s" % (
self.preparer.format_column(
- create.element, use_table=True, use_schema=True),
+ create.element, use_table=True, use_schema=True
+ ),
self.sql_compiler.render_literal_value(
- create.element.comment, sqltypes.String())
+ create.element.comment, sqltypes.String()
+ ),
)
def visit_drop_column_comment(self, drop):
- return "COMMENT ON COLUMN %s IS NULL" % \
- self.preparer.format_column(drop.element, use_table=True)
+ return "COMMENT ON COLUMN %s IS NULL" % self.preparer.format_column(
+ drop.element, use_table=True
+ )
def visit_create_sequence(self, create):
- text = "CREATE SEQUENCE %s" % \
- self.preparer.format_sequence(create.element)
+ text = "CREATE SEQUENCE %s" % self.preparer.format_sequence(
+ create.element
+ )
if create.element.increment is not None:
text += " INCREMENT BY %d" % create.element.increment
if create.element.start is not None:
@@ -2688,8 +3009,7 @@ class DDLCompiler(Compiled):
return text
def visit_drop_sequence(self, drop):
- return "DROP SEQUENCE %s" % \
- self.preparer.format_sequence(drop.element)
+ return "DROP SEQUENCE %s" % self.preparer.format_sequence(drop.element)
def visit_drop_constraint(self, drop):
constraint = drop.element
@@ -2701,17 +3021,22 @@ class DDLCompiler(Compiled):
if formatted_name is None:
raise exc.CompileError(
"Can't emit DROP CONSTRAINT for constraint %r; "
- "it has no name" % drop.element)
+ "it has no name" % drop.element
+ )
return "ALTER TABLE %s DROP CONSTRAINT %s%s" % (
self.preparer.format_table(drop.element.table),
formatted_name,
- drop.cascade and " CASCADE" or ""
+ drop.cascade and " CASCADE" or "",
)
def get_column_specification(self, column, **kwargs):
- colspec = self.preparer.format_column(column) + " " + \
- self.dialect.type_compiler.process(
- column.type, type_expression=column)
+ colspec = (
+ self.preparer.format_column(column)
+ + " "
+ + self.dialect.type_compiler.process(
+ column.type, type_expression=column
+ )
+ )
default = self.get_column_default_string(column)
if default is not None:
colspec += " DEFAULT " + default
@@ -2721,19 +3046,21 @@ class DDLCompiler(Compiled):
return colspec
def create_table_suffix(self, table):
- return ''
+ return ""
def post_create_table(self, table):
- return ''
+ return ""
def get_column_default_string(self, column):
if isinstance(column.server_default, schema.DefaultClause):
if isinstance(column.server_default.arg, util.string_types):
return self.sql_compiler.render_literal_value(
- column.server_default.arg, sqltypes.STRINGTYPE)
+ column.server_default.arg, sqltypes.STRINGTYPE
+ )
else:
return self.sql_compiler.process(
- column.server_default.arg, literal_binds=True)
+ column.server_default.arg, literal_binds=True
+ )
else:
return None
@@ -2743,9 +3070,9 @@ class DDLCompiler(Compiled):
formatted_name = self.preparer.format_constraint(constraint)
if formatted_name is not None:
text += "CONSTRAINT %s " % formatted_name
- text += "CHECK (%s)" % self.sql_compiler.process(constraint.sqltext,
- include_table=False,
- literal_binds=True)
+ text += "CHECK (%s)" % self.sql_compiler.process(
+ constraint.sqltext, include_table=False, literal_binds=True
+ )
text += self.define_constraint_deferrability(constraint)
return text
@@ -2755,25 +3082,29 @@ class DDLCompiler(Compiled):
formatted_name = self.preparer.format_constraint(constraint)
if formatted_name is not None:
text += "CONSTRAINT %s " % formatted_name
- text += "CHECK (%s)" % self.sql_compiler.process(constraint.sqltext,
- include_table=False,
- literal_binds=True)
+ text += "CHECK (%s)" % self.sql_compiler.process(
+ constraint.sqltext, include_table=False, literal_binds=True
+ )
text += self.define_constraint_deferrability(constraint)
return text
def visit_primary_key_constraint(self, constraint):
if len(constraint) == 0:
- return ''
+ return ""
text = ""
if constraint.name is not None:
formatted_name = self.preparer.format_constraint(constraint)
if formatted_name is not None:
text += "CONSTRAINT %s " % formatted_name
text += "PRIMARY KEY "
- text += "(%s)" % ', '.join(self.preparer.quote(c.name)
- for c in (constraint.columns_autoinc_first
- if constraint._implicit_generated
- else constraint.columns))
+ text += "(%s)" % ", ".join(
+ self.preparer.quote(c.name)
+ for c in (
+ constraint.columns_autoinc_first
+ if constraint._implicit_generated
+ else constraint.columns
+ )
+ )
text += self.define_constraint_deferrability(constraint)
return text
@@ -2786,12 +3117,15 @@ class DDLCompiler(Compiled):
text += "CONSTRAINT %s " % formatted_name
remote_table = list(constraint.elements)[0].column.table
text += "FOREIGN KEY(%s) REFERENCES %s (%s)" % (
- ', '.join(preparer.quote(f.parent.name)
- for f in constraint.elements),
+ ", ".join(
+ preparer.quote(f.parent.name) for f in constraint.elements
+ ),
self.define_constraint_remote_table(
- constraint, remote_table, preparer),
- ', '.join(preparer.quote(f.column.name)
- for f in constraint.elements)
+ constraint, remote_table, preparer
+ ),
+ ", ".join(
+ preparer.quote(f.column.name) for f in constraint.elements
+ ),
)
text += self.define_constraint_match(constraint)
text += self.define_constraint_cascades(constraint)
@@ -2805,14 +3139,14 @@ class DDLCompiler(Compiled):
def visit_unique_constraint(self, constraint):
if len(constraint) == 0:
- return ''
+ return ""
text = ""
if constraint.name is not None:
formatted_name = self.preparer.format_constraint(constraint)
text += "CONSTRAINT %s " % formatted_name
text += "UNIQUE (%s)" % (
- ', '.join(self.preparer.quote(c.name)
- for c in constraint))
+ ", ".join(self.preparer.quote(c.name) for c in constraint)
+ )
text += self.define_constraint_deferrability(constraint)
return text
@@ -2843,7 +3177,6 @@ class DDLCompiler(Compiled):
class GenericTypeCompiler(TypeCompiler):
-
def visit_FLOAT(self, type_, **kw):
return "FLOAT"
@@ -2854,23 +3187,23 @@ class GenericTypeCompiler(TypeCompiler):
if type_.precision is None:
return "NUMERIC"
elif type_.scale is None:
- return "NUMERIC(%(precision)s)" % \
- {'precision': type_.precision}
+ return "NUMERIC(%(precision)s)" % {"precision": type_.precision}
else:
- return "NUMERIC(%(precision)s, %(scale)s)" % \
- {'precision': type_.precision,
- 'scale': type_.scale}
+ return "NUMERIC(%(precision)s, %(scale)s)" % {
+ "precision": type_.precision,
+ "scale": type_.scale,
+ }
def visit_DECIMAL(self, type_, **kw):
if type_.precision is None:
return "DECIMAL"
elif type_.scale is None:
- return "DECIMAL(%(precision)s)" % \
- {'precision': type_.precision}
+ return "DECIMAL(%(precision)s)" % {"precision": type_.precision}
else:
- return "DECIMAL(%(precision)s, %(scale)s)" % \
- {'precision': type_.precision,
- 'scale': type_.scale}
+ return "DECIMAL(%(precision)s, %(scale)s)" % {
+ "precision": type_.precision,
+ "scale": type_.scale,
+ }
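# These visitors render NUMERIC, NUMERIC(10) or NUMERIC(10, 2)
# depending on which of precision/scale are set; a quick sketch:
from sqlalchemy import Numeric

print(Numeric(10, 2))  # NUMERIC(10, 2)
print(Numeric(10))     # NUMERIC(10)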
def visit_INTEGER(self, type_, **kw):
return "INTEGER"
@@ -2882,7 +3215,7 @@ class GenericTypeCompiler(TypeCompiler):
return "BIGINT"
def visit_TIMESTAMP(self, type_, **kw):
- return 'TIMESTAMP'
+ return "TIMESTAMP"
def visit_DATETIME(self, type_, **kw):
return "DATETIME"
@@ -2984,9 +3317,11 @@ class GenericTypeCompiler(TypeCompiler):
return self.visit_VARCHAR(type_, **kw)
def visit_null(self, type_, **kw):
- raise exc.CompileError("Can't generate DDL for %r; "
- "did you forget to specify a "
- "type on this Column?" % type_)
+ raise exc.CompileError(
+ "Can't generate DDL for %r; "
+ "did you forget to specify a "
+ "type on this Column?" % type_
+ )
def visit_type_decorator(self, type_, **kw):
return self.process(type_.type_engine(self.dialect), **kw)
@@ -3018,9 +3353,15 @@ class IdentifierPreparer(object):
schema_for_object = schema._schema_getter(None)
- def __init__(self, dialect, initial_quote='"',
- final_quote=None, escape_quote='"',
- quote_case_sensitive_collations=True, omit_schema=False):
+ def __init__(
+ self,
+ dialect,
+ initial_quote='"',
+ final_quote=None,
+ escape_quote='"',
+ quote_case_sensitive_collations=True,
+ omit_schema=False,
+ ):
"""Construct a new ``IdentifierPreparer`` object.
initial_quote
@@ -3043,7 +3384,10 @@ class IdentifierPreparer(object):
self.omit_schema = omit_schema
self.quote_case_sensitive_collations = quote_case_sensitive_collations
self._strings = {}
- self._double_percents = self.dialect.paramstyle in ('format', 'pyformat')
+ self._double_percents = self.dialect.paramstyle in (
+ "format",
+ "pyformat",
+ )
def _with_schema_translate(self, schema_translate_map):
prep = self.__class__.__new__(self.__class__)
@@ -3060,7 +3404,7 @@ class IdentifierPreparer(object):
value = value.replace(self.escape_quote, self.escape_to_quote)
if self._double_percents:
- value = value.replace('%', '%%')
+ value = value.replace("%", "%%")
return value
def _unescape_identifier(self, value):
@@ -3079,17 +3423,21 @@ class IdentifierPreparer(object):
quoting behavior.
"""
- return self.initial_quote + \
- self._escape_identifier(value) + \
- self.final_quote
+ return (
+ self.initial_quote
+ + self._escape_identifier(value)
+ + self.final_quote
+ )
def _requires_quotes(self, value):
"""Return True if the given identifier requires quoting."""
lc_value = value.lower()
- return (lc_value in self.reserved_words
- or value[0] in self.illegal_initial_characters
- or not self.legal_characters.match(util.text_type(value))
- or (lc_value != value))
+ return (
+ lc_value in self.reserved_words
+ or value[0] in self.illegal_initial_characters
+ or not self.legal_characters.match(util.text_type(value))
+ or (lc_value != value)
+ )
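# Illustration of the rule above (a sketch): reserved words, illegal
# leading characters and mixed case all force quoting:
from sqlalchemy import column

print(column("select"))     # "select"    (reserved word)
print(column("MixedCase"))  # "MixedCase" (not all lower case)
print(column("plain"))      # plain       (no quoting needed)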
def quote_schema(self, schema, force=None):
"""Conditionally quote a schema.
@@ -3135,8 +3483,11 @@ class IdentifierPreparer(object):
effective_schema = self.schema_for_object(sequence)
- if (not self.omit_schema and use_schema and
- effective_schema is not None):
+ if (
+ not self.omit_schema
+ and use_schema
+ and effective_schema is not None
+ ):
name = self.quote_schema(effective_schema) + "." + name
return name
@@ -3159,7 +3510,8 @@ class IdentifierPreparer(object):
def format_constraint(self, naming, constraint):
if isinstance(constraint.name, elements._defer_name):
name = naming._constraint_name_for_table(
- constraint, constraint.table)
+ constraint, constraint.table
+ )
if name is None:
if isinstance(constraint.name, elements._defer_none_name):
@@ -3170,14 +3522,15 @@ class IdentifierPreparer(object):
name = constraint.name
if isinstance(name, elements._truncated_label):
- if constraint.__visit_name__ == 'index':
- max_ = self.dialect.max_index_name_length or \
- self.dialect.max_identifier_length
+ if constraint.__visit_name__ == "index":
+ max_ = (
+ self.dialect.max_index_name_length
+ or self.dialect.max_identifier_length
+ )
else:
max_ = self.dialect.max_identifier_length
if len(name) > max_:
- name = name[0:max_ - 8] + \
- "_" + util.md5_hex(name)[-4:]
+ name = name[0 : max_ - 8] + "_" + util.md5_hex(name)[-4:]
else:
self.dialect.validate_identifier(name)
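# In other words (sketch): an over-long truncated label keeps its first
# (max - 8) characters and gains "_" plus the last four hex digits of
# its md5 hash, keeping generated constraint and index names unique
# while fitting the dialect's identifier length limit.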
@@ -3195,8 +3548,7 @@ class IdentifierPreparer(object):
effective_schema = self.schema_for_object(table)
- if not self.omit_schema and use_schema \
- and effective_schema:
+ if not self.omit_schema and use_schema and effective_schema:
result = self.quote_schema(effective_schema) + "." + result
return result
@@ -3205,17 +3557,27 @@ class IdentifierPreparer(object):
return self.quote(name, quote)
- def format_column(self, column, use_table=False,
- name=None, table_name=None, use_schema=False):
+ def format_column(
+ self,
+ column,
+ use_table=False,
+ name=None,
+ table_name=None,
+ use_schema=False,
+ ):
"""Prepare a quoted column name."""
if name is None:
name = column.name
- if not getattr(column, 'is_literal', False):
+ if not getattr(column, "is_literal", False):
if use_table:
- return self.format_table(
- column.table, use_schema=use_schema,
- name=table_name) + "." + self.quote(name)
+ return (
+ self.format_table(
+ column.table, use_schema=use_schema, name=table_name
+ )
+ + "."
+ + self.quote(name)
+ )
else:
return self.quote(name)
else:
@@ -3223,9 +3585,13 @@ class IdentifierPreparer(object):
# which shouldn't get quoted
if use_table:
- return self.format_table(
- column.table, use_schema=use_schema,
- name=table_name) + '.' + name
+ return (
+ self.format_table(
+ column.table, use_schema=use_schema, name=table_name
+ )
+ + "."
+ + name
+ )
else:
return name
@@ -3238,31 +3604,37 @@ class IdentifierPreparer(object):
effective_schema = self.schema_for_object(table)
- if not self.omit_schema and use_schema and \
- effective_schema:
- return (self.quote_schema(effective_schema),
- self.format_table(table, use_schema=False))
+ if not self.omit_schema and use_schema and effective_schema:
+ return (
+ self.quote_schema(effective_schema),
+ self.format_table(table, use_schema=False),
+ )
else:
- return (self.format_table(table, use_schema=False), )
+ return (self.format_table(table, use_schema=False),)
@util.memoized_property
def _r_identifiers(self):
- initial, final, escaped_final = \
- [re.escape(s) for s in
- (self.initial_quote, self.final_quote,
- self._escape_identifier(self.final_quote))]
+ initial, final, escaped_final = [
+ re.escape(s)
+ for s in (
+ self.initial_quote,
+ self.final_quote,
+ self._escape_identifier(self.final_quote),
+ )
+ ]
r = re.compile(
- r'(?:'
- r'(?:%(initial)s((?:%(escaped)s|[^%(final)s])+)%(final)s'
- r'|([^\.]+))(?=\.|$))+' %
- {'initial': initial,
- 'final': final,
- 'escaped': escaped_final})
+ r"(?:"
+ r"(?:%(initial)s((?:%(escaped)s|[^%(final)s])+)%(final)s"
+ r"|([^\.]+))(?=\.|$))+"
+ % {"initial": initial, "final": final, "escaped": escaped_final}
+ )
return r
def unformat_identifiers(self, identifiers):
"""Unpack 'schema.table.column'-like strings into components."""
r = self._r_identifiers
- return [self._unescape_identifier(i)
- for i in [a or b for a, b in r.findall(identifiers)]]
+ return [
+ self._unescape_identifier(i)
+ for i in [a or b for a, b in r.findall(identifiers)]
+ ]
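# Usage sketch for the helper above:
#   preparer.unformat_identifiers('"my schema".mytable.col')
#   returns ['my schema', 'mytable', 'col']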
diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py
index 999d48a55..602b91a25 100644
--- a/lib/sqlalchemy/sql/crud.py
+++ b/lib/sqlalchemy/sql/crud.py
@@ -15,7 +15,9 @@ from . import dml
from . import elements
import operator
-REQUIRED = util.symbol('REQUIRED', """
+REQUIRED = util.symbol(
+ "REQUIRED",
+ """
Placeholder for the value within a :class:`.BindParameter`
which is required to be present when the statement is passed
to :meth:`.Connection.execute`.
@@ -24,11 +26,12 @@ This symbol is typically used when a :func:`.expression.insert`
or :func:`.expression.update` statement is compiled without parameter
values present.
-""")
+""",
+)
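# Sketch of REQUIRED in practice: compiling an insert() without
# parameters leaves "required" binds for every column, which must then
# be supplied at execution time or execute() raises:
from sqlalchemy import Column, Integer, MetaData, Table

m = MetaData()
t = Table("t", m, Column("x", Integer), Column("y", Integer))
print(t.insert().compile())  # INSERT INTO t (x, y) VALUES (:x, :y)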
-ISINSERT = util.symbol('ISINSERT')
-ISUPDATE = util.symbol('ISUPDATE')
-ISDELETE = util.symbol('ISDELETE')
+ISINSERT = util.symbol("ISINSERT")
+ISUPDATE = util.symbol("ISUPDATE")
+ISDELETE = util.symbol("ISDELETE")
def _setup_crud_params(compiler, stmt, local_stmt_type, **kw):
@@ -82,8 +85,7 @@ def _get_crud_params(compiler, stmt, **kw):
# compiled params - return binds for all columns
if compiler.column_keys is None and stmt.parameters is None:
return [
- (c, _create_bind_param(
- compiler, c, None, required=True))
+ (c, _create_bind_param(compiler, c, None, required=True))
for c in stmt.table.columns
]
@@ -95,26 +97,28 @@ def _get_crud_params(compiler, stmt, **kw):
# getters - these are normally just column.key,
# but in the case of mysql multi-table update, the rules for
# .key must conditionally take tablename into account
- _column_as_key, _getattr_col_key, _col_bind_name = \
- _key_getters_for_crud_column(compiler, stmt)
+ _column_as_key, _getattr_col_key, _col_bind_name = _key_getters_for_crud_column(
+ compiler, stmt
+ )
# if we have statement parameters - set defaults in the
# compiled params
if compiler.column_keys is None:
parameters = {}
else:
- parameters = dict((_column_as_key(key), REQUIRED)
- for key in compiler.column_keys
- if not stmt_parameters or
- key not in stmt_parameters)
+ parameters = dict(
+ (_column_as_key(key), REQUIRED)
+ for key in compiler.column_keys
+ if not stmt_parameters or key not in stmt_parameters
+ )
# create a list of column assignment clauses as tuples
values = []
if stmt_parameters is not None:
_get_stmt_parameters_params(
- compiler,
- parameters, stmt_parameters, _column_as_key, values, kw)
+ compiler, parameters, stmt_parameters, _column_as_key, values, kw
+ )
check_columns = {}
@@ -122,28 +126,51 @@ def _get_crud_params(compiler, stmt, **kw):
# statements
if compiler.isupdate and stmt._extra_froms and stmt_parameters:
_get_multitable_params(
- compiler, stmt, stmt_parameters, check_columns,
- _col_bind_name, _getattr_col_key, values, kw)
+ compiler,
+ stmt,
+ stmt_parameters,
+ check_columns,
+ _col_bind_name,
+ _getattr_col_key,
+ values,
+ kw,
+ )
if compiler.isinsert and stmt.select_names:
_scan_insert_from_select_cols(
- compiler, stmt, parameters,
- _getattr_col_key, _column_as_key,
- _col_bind_name, check_columns, values, kw)
+ compiler,
+ stmt,
+ parameters,
+ _getattr_col_key,
+ _column_as_key,
+ _col_bind_name,
+ check_columns,
+ values,
+ kw,
+ )
else:
_scan_cols(
- compiler, stmt, parameters,
- _getattr_col_key, _column_as_key,
- _col_bind_name, check_columns, values, kw)
+ compiler,
+ stmt,
+ parameters,
+ _getattr_col_key,
+ _column_as_key,
+ _col_bind_name,
+ check_columns,
+ values,
+ kw,
+ )
if parameters and stmt_parameters:
- check = set(parameters).intersection(
- _column_as_key(k) for k in stmt_parameters
- ).difference(check_columns)
+ check = (
+ set(parameters)
+ .intersection(_column_as_key(k) for k in stmt_parameters)
+ .difference(check_columns)
+ )
if check:
raise exc.CompileError(
- "Unconsumed column names: %s" %
- (", ".join("%s" % c for c in check))
+ "Unconsumed column names: %s"
+ % (", ".join("%s" % c for c in check))
)
if stmt._has_multi_parameters:
@@ -153,12 +180,13 @@ def _get_crud_params(compiler, stmt, **kw):
def _create_bind_param(
- compiler, col, value, process=True,
- required=False, name=None, **kw):
+ compiler, col, value, process=True, required=False, name=None, **kw
+):
if name is None:
name = col.key
bindparam = elements.BindParameter(
- name, value, type_=col.type, required=required)
+ name, value, type_=col.type, required=required
+ )
bindparam._is_crud = True
if process:
bindparam = bindparam._compiler_dispatch(compiler, **kw)
@@ -177,7 +205,7 @@ def _key_getters_for_crud_column(compiler, stmt):
def _column_as_key(key):
str_key = elements._column_as_key(key)
- if hasattr(key, 'table') and key.table in _et:
+ if hasattr(key, "table") and key.table in _et:
return (key.table.name, str_key)
else:
return str_key
@@ -202,15 +230,22 @@ def _key_getters_for_crud_column(compiler, stmt):
def _scan_insert_from_select_cols(
- compiler, stmt, parameters, _getattr_col_key,
- _column_as_key, _col_bind_name, check_columns, values, kw):
-
- need_pks, implicit_returning, \
- implicit_return_defaults, postfetch_lastrowid = \
- _get_returning_modifiers(compiler, stmt)
+ compiler,
+ stmt,
+ parameters,
+ _getattr_col_key,
+ _column_as_key,
+ _col_bind_name,
+ check_columns,
+ values,
+ kw,
+):
+
+ need_pks, implicit_returning, implicit_return_defaults, postfetch_lastrowid = _get_returning_modifiers(
+ compiler, stmt
+ )
- cols = [stmt.table.c[_column_as_key(name)]
- for name in stmt.select_names]
+ cols = [stmt.table.c[_column_as_key(name)] for name in stmt.select_names]
compiler._insert_from_select = stmt.select
@@ -228,32 +263,39 @@ def _scan_insert_from_select_cols(
values.append((c, None))
else:
_append_param_insert_select_hasdefault(
- compiler, stmt, c, add_select_cols, kw)
+ compiler, stmt, c, add_select_cols, kw
+ )
if add_select_cols:
values.extend(add_select_cols)
compiler._insert_from_select = compiler._insert_from_select._generate()
- compiler._insert_from_select._raw_columns = \
- tuple(compiler._insert_from_select._raw_columns) + tuple(
- expr for col, expr in add_select_cols)
+ compiler._insert_from_select._raw_columns = tuple(
+ compiler._insert_from_select._raw_columns
+ ) + tuple(expr for col, expr in add_select_cols)
def _scan_cols(
- compiler, stmt, parameters, _getattr_col_key,
- _column_as_key, _col_bind_name, check_columns, values, kw):
-
- need_pks, implicit_returning, \
- implicit_return_defaults, postfetch_lastrowid = \
- _get_returning_modifiers(compiler, stmt)
+ compiler,
+ stmt,
+ parameters,
+ _getattr_col_key,
+ _column_as_key,
+ _col_bind_name,
+ check_columns,
+ values,
+ kw,
+):
+
+ need_pks, implicit_returning, implicit_return_defaults, postfetch_lastrowid = _get_returning_modifiers(
+ compiler, stmt
+ )
if stmt._parameter_ordering:
parameter_ordering = [
_column_as_key(key) for key in stmt._parameter_ordering
]
ordered_keys = set(parameter_ordering)
- cols = [
- stmt.table.c[key] for key in parameter_ordering
- ] + [
+ cols = [stmt.table.c[key] for key in parameter_ordering] + [
c for c in stmt.table.c if c.key not in ordered_keys
]
else:
@@ -265,72 +307,95 @@ def _scan_cols(
if col_key in parameters and col_key not in check_columns:
_append_param_parameter(
- compiler, stmt, c, col_key, parameters, _col_bind_name,
- implicit_returning, implicit_return_defaults, values, kw)
+ compiler,
+ stmt,
+ c,
+ col_key,
+ parameters,
+ _col_bind_name,
+ implicit_returning,
+ implicit_return_defaults,
+ values,
+ kw,
+ )
elif compiler.isinsert:
- if c.primary_key and \
- need_pks and \
- (
- implicit_returning or
- not postfetch_lastrowid or
- c is not stmt.table._autoincrement_column
- ):
+ if (
+ c.primary_key
+ and need_pks
+ and (
+ implicit_returning
+ or not postfetch_lastrowid
+ or c is not stmt.table._autoincrement_column
+ )
+ ):
if implicit_returning:
_append_param_insert_pk_returning(
- compiler, stmt, c, values, kw)
+ compiler, stmt, c, values, kw
+ )
else:
_append_param_insert_pk(compiler, stmt, c, values, kw)
elif c.default is not None:
_append_param_insert_hasdefault(
- compiler, stmt, c, implicit_return_defaults,
- values, kw)
+ compiler, stmt, c, implicit_return_defaults, values, kw
+ )
elif c.server_default is not None:
- if implicit_return_defaults and \
- c in implicit_return_defaults:
+ if implicit_return_defaults and c in implicit_return_defaults:
compiler.returning.append(c)
elif not c.primary_key:
compiler.postfetch.append(c)
- elif implicit_return_defaults and \
- c in implicit_return_defaults:
+ elif implicit_return_defaults and c in implicit_return_defaults:
compiler.returning.append(c)
- elif c.primary_key and \
- c is not stmt.table._autoincrement_column and \
- not c.nullable:
+ elif (
+ c.primary_key
+ and c is not stmt.table._autoincrement_column
+ and not c.nullable
+ ):
_warn_pk_with_no_anticipated_value(c)
elif compiler.isupdate:
_append_param_update(
- compiler, stmt, c, implicit_return_defaults, values, kw)
+ compiler, stmt, c, implicit_return_defaults, values, kw
+ )
def _append_param_parameter(
- compiler, stmt, c, col_key, parameters, _col_bind_name,
- implicit_returning, implicit_return_defaults, values, kw):
+ compiler,
+ stmt,
+ c,
+ col_key,
+ parameters,
+ _col_bind_name,
+ implicit_returning,
+ implicit_return_defaults,
+ values,
+ kw,
+):
value = parameters.pop(col_key)
if elements._is_literal(value):
value = _create_bind_param(
- compiler, c, value, required=value is REQUIRED,
+ compiler,
+ c,
+ value,
+ required=value is REQUIRED,
name=_col_bind_name(c)
if not stmt._has_multi_parameters
else "%s_m0" % _col_bind_name(c),
**kw
)
else:
- if isinstance(value, elements.BindParameter) and \
- value.type._isnull:
+ if isinstance(value, elements.BindParameter) and value.type._isnull:
value = value._clone()
value.type = c.type
if c.primary_key and implicit_returning:
compiler.returning.append(c)
value = compiler.process(value.self_group(), **kw)
- elif implicit_return_defaults and \
- c in implicit_return_defaults:
+ elif implicit_return_defaults and c in implicit_return_defaults:
compiler.returning.append(c)
value = compiler.process(value.self_group(), **kw)
else:
@@ -358,22 +423,20 @@ def _append_param_insert_pk_returning(compiler, stmt, c, values, kw):
"""
if c.default is not None:
if c.default.is_sequence:
- if compiler.dialect.supports_sequences and \
- (not c.default.optional or
- not compiler.dialect.sequences_optional):
+ if compiler.dialect.supports_sequences and (
+ not c.default.optional
+ or not compiler.dialect.sequences_optional
+ ):
proc = compiler.process(c.default, **kw)
values.append((c, proc))
compiler.returning.append(c)
elif c.default.is_clause_element:
values.append(
- (c, compiler.process(
- c.default.arg.self_group(), **kw))
+ (c, compiler.process(c.default.arg.self_group(), **kw))
)
compiler.returning.append(c)
else:
- values.append(
- (c, _create_insert_prefetch_bind_param(compiler, c))
- )
+ values.append((c, _create_insert_prefetch_bind_param(compiler, c)))
elif c is stmt.table._autoincrement_column or c.server_default is not None:
compiler.returning.append(c)
elif not c.nullable:
@@ -405,9 +468,11 @@ class _multiparam_column(elements.ColumnElement):
self.type = original.type
def __eq__(self, other):
- return isinstance(other, _multiparam_column) and \
- other.key == self.key and \
- other.original == self.original
+ return (
+ isinstance(other, _multiparam_column)
+ and other.key == self.key
+ and other.original == self.original
+ )
def _process_multiparam_default_bind(compiler, stmt, c, index, kw):
@@ -416,7 +481,8 @@ def _process_multiparam_default_bind(compiler, stmt, c, index, kw):
raise exc.CompileError(
"INSERT value for column %s is explicitly rendered as a bound"
"parameter in the VALUES clause; "
- "a Python-side value or SQL expression is required" % c)
+ "a Python-side value or SQL expression is required" % c
+ )
elif c.default.is_clause_element:
return compiler.process(c.default.arg.self_group(), **kw)
else:
@@ -440,30 +506,24 @@ def _append_param_insert_pk(compiler, stmt, c, values, kw):
"""
if (
- (
- # column has a Python-side default
- c.default is not None and
- (
- # and it won't be a Sequence
- not c.default.is_sequence or
- compiler.dialect.supports_sequences
- )
- )
- or
- (
- # column is the "autoincrement column"
- c is stmt.table._autoincrement_column and
- (
- # and it's either a "sequence" or a
- # pre-executable "autoincrement" sequence
- compiler.dialect.supports_sequences or
- compiler.dialect.preexecute_autoincrement_sequences
- )
- )
- ):
- values.append(
- (c, _create_insert_prefetch_bind_param(compiler, c))
+ # column has a Python-side default
+ c.default is not None
+ and (
+ # and it won't be a Sequence
+ not c.default.is_sequence
+ or compiler.dialect.supports_sequences
)
+ ) or (
+ # column is the "autoincrement column"
+ c is stmt.table._autoincrement_column
+ and (
+ # and it's either a "sequence" or a
+ # pre-executable "autoincrement" sequence
+ compiler.dialect.supports_sequences
+ or compiler.dialect.preexecute_autoincrement_sequences
+ )
+ ):
+ values.append((c, _create_insert_prefetch_bind_param(compiler, c)))
elif c.default is None and c.server_default is None and not c.nullable:
# no .default, no .server_default, not autoincrement, we have
# no indication this primary key column will have any value
@@ -471,16 +531,16 @@ def _append_param_insert_pk(compiler, stmt, c, values, kw):
def _append_param_insert_hasdefault(
- compiler, stmt, c, implicit_return_defaults, values, kw):
+ compiler, stmt, c, implicit_return_defaults, values, kw
+):
if c.default.is_sequence:
- if compiler.dialect.supports_sequences and \
- (not c.default.optional or
- not compiler.dialect.sequences_optional):
+ if compiler.dialect.supports_sequences and (
+ not c.default.optional or not compiler.dialect.sequences_optional
+ ):
proc = compiler.process(c.default, **kw)
values.append((c, proc))
- if implicit_return_defaults and \
- c in implicit_return_defaults:
+ if implicit_return_defaults and c in implicit_return_defaults:
compiler.returning.append(c)
elif not c.primary_key:
compiler.postfetch.append(c)
@@ -488,25 +548,21 @@ def _append_param_insert_hasdefault(
proc = compiler.process(c.default.arg.self_group(), **kw)
values.append((c, proc))
- if implicit_return_defaults and \
- c in implicit_return_defaults:
+ if implicit_return_defaults and c in implicit_return_defaults:
compiler.returning.append(c)
elif not c.primary_key:
# don't add primary key column to postfetch
compiler.postfetch.append(c)
else:
- values.append(
- (c, _create_insert_prefetch_bind_param(compiler, c))
- )
+ values.append((c, _create_insert_prefetch_bind_param(compiler, c)))
-def _append_param_insert_select_hasdefault(
- compiler, stmt, c, values, kw):
+def _append_param_insert_select_hasdefault(compiler, stmt, c, values, kw):
if c.default.is_sequence:
- if compiler.dialect.supports_sequences and \
- (not c.default.optional or
- not compiler.dialect.sequences_optional):
+ if compiler.dialect.supports_sequences and (
+ not c.default.optional or not compiler.dialect.sequences_optional
+ ):
proc = c.default
values.append((c, proc.next_value()))
elif c.default.is_clause_element:
@@ -519,38 +575,43 @@ def _append_param_insert_select_hasdefault(
def _append_param_update(
- compiler, stmt, c, implicit_return_defaults, values, kw):
+ compiler, stmt, c, implicit_return_defaults, values, kw
+):
if c.onupdate is not None and not c.onupdate.is_sequence:
if c.onupdate.is_clause_element:
values.append(
- (c, compiler.process(
- c.onupdate.arg.self_group(), **kw))
+ (c, compiler.process(c.onupdate.arg.self_group(), **kw))
)
- if implicit_return_defaults and \
- c in implicit_return_defaults:
+ if implicit_return_defaults and c in implicit_return_defaults:
compiler.returning.append(c)
else:
compiler.postfetch.append(c)
else:
- values.append(
- (c, _create_update_prefetch_bind_param(compiler, c))
- )
+ values.append((c, _create_update_prefetch_bind_param(compiler, c)))
elif c.server_onupdate is not None:
- if implicit_return_defaults and \
- c in implicit_return_defaults:
+ if implicit_return_defaults and c in implicit_return_defaults:
compiler.returning.append(c)
else:
compiler.postfetch.append(c)
- elif implicit_return_defaults and \
- stmt._return_defaults is not True and \
- c in implicit_return_defaults:
+ elif (
+ implicit_return_defaults
+ and stmt._return_defaults is not True
+ and c in implicit_return_defaults
+ ):
compiler.returning.append(c)
def _get_multitable_params(
- compiler, stmt, stmt_parameters, check_columns,
- _col_bind_name, _getattr_col_key, values, kw):
+ compiler,
+ stmt,
+ stmt_parameters,
+ check_columns,
+ _col_bind_name,
+ _getattr_col_key,
+ values,
+ kw,
+):
normalized_params = dict(
(elements._clause_element_as_expr(c), param)
@@ -565,8 +626,12 @@ def _get_multitable_params(
value = normalized_params[c]
if elements._is_literal(value):
value = _create_bind_param(
- compiler, c, value, required=value is REQUIRED,
- name=_col_bind_name(c))
+ compiler,
+ c,
+ value,
+ required=value is REQUIRED,
+ name=_col_bind_name(c),
+ )
else:
compiler.postfetch.append(c)
value = compiler.process(value.self_group(), **kw)
@@ -577,20 +642,25 @@ def _get_multitable_params(
for c in t.c:
if c in normalized_params:
continue
- elif (c.onupdate is not None and not
- c.onupdate.is_sequence):
+ elif c.onupdate is not None and not c.onupdate.is_sequence:
if c.onupdate.is_clause_element:
values.append(
- (c, compiler.process(
- c.onupdate.arg.self_group(),
- **kw)
- )
+ (
+ c,
+ compiler.process(
+ c.onupdate.arg.self_group(), **kw
+ ),
+ )
)
compiler.postfetch.append(c)
else:
values.append(
- (c, _create_update_prefetch_bind_param(
- compiler, c, name=_col_bind_name(c)))
+ (
+ c,
+ _create_update_prefetch_bind_param(
+ compiler, c, name=_col_bind_name(c)
+ ),
+ )
)
elif c.server_onupdate is not None:
compiler.postfetch.append(c)
@@ -608,8 +678,11 @@ def _extend_values_for_multiparams(compiler, stmt, values, kw):
if elements._is_literal(row[key]):
new_param = _create_bind_param(
- compiler, col, row[key],
- name="%s_m%d" % (col.key, i + 1), **kw
+ compiler,
+ col,
+ row[key],
+ name="%s_m%d" % (col.key, i + 1),
+ **kw
)
else:
new_param = compiler.process(row[key].self_group(), **kw)
@@ -626,7 +699,8 @@ def _extend_values_for_multiparams(compiler, stmt, values, kw):
def _get_stmt_parameters_params(
- compiler, parameters, stmt_parameters, _column_as_key, values, kw):
+ compiler, parameters, stmt_parameters, _column_as_key, values, kw
+):
for k, v in stmt_parameters.items():
colkey = _column_as_key(k)
if colkey is not None:
@@ -637,8 +711,8 @@ def _get_stmt_parameters_params(
# coercing right side to bound param
if elements._is_literal(v):
v = compiler.process(
- elements.BindParameter(None, v, type_=k.type),
- **kw)
+ elements.BindParameter(None, v, type_=k.type), **kw
+ )
else:
v = compiler.process(v.self_group(), **kw)
@@ -646,22 +720,27 @@ def _get_stmt_parameters_params(
def _get_returning_modifiers(compiler, stmt):
- need_pks = compiler.isinsert and \
- not compiler.inline and \
- not stmt._returning and \
- not stmt._has_multi_parameters
+ need_pks = (
+ compiler.isinsert
+ and not compiler.inline
+ and not stmt._returning
+ and not stmt._has_multi_parameters
+ )
- implicit_returning = need_pks and \
- compiler.dialect.implicit_returning and \
- stmt.table.implicit_returning
+ implicit_returning = (
+ need_pks
+ and compiler.dialect.implicit_returning
+ and stmt.table.implicit_returning
+ )
if compiler.isinsert:
- implicit_return_defaults = (implicit_returning and
- stmt._return_defaults)
+ implicit_return_defaults = implicit_returning and stmt._return_defaults
elif compiler.isupdate:
- implicit_return_defaults = (compiler.dialect.implicit_returning and
- stmt.table.implicit_returning and
- stmt._return_defaults)
+ implicit_return_defaults = (
+ compiler.dialect.implicit_returning
+ and stmt.table.implicit_returning
+ and stmt._return_defaults
+ )
else:
# this line is unused, currently we are always
# isinsert or isupdate
@@ -675,8 +754,12 @@ def _get_returning_modifiers(compiler, stmt):
postfetch_lastrowid = need_pks and compiler.dialect.postfetch_lastrowid
- return need_pks, implicit_returning, \
- implicit_return_defaults, postfetch_lastrowid
+ return (
+ need_pks,
+ implicit_returning,
+ implicit_return_defaults,
+ postfetch_lastrowid,
+ )
def _warn_pk_with_no_anticipated_value(c):
@@ -687,8 +770,8 @@ def _warn_pk_with_no_anticipated_value(c):
"nor does it indicate 'autoincrement=True' or 'nullable=True', "
"and no explicit value is passed. "
"Primary key columns typically may not store NULL."
- %
- (c.table.fullname, c.name, c.table.fullname))
+ % (c.table.fullname, c.name, c.table.fullname)
+ )
if len(c.table.primary_key) > 1:
msg += (
" Note that as of SQLAlchemy 1.1, 'autoincrement=True' must be "
@@ -696,5 +779,6 @@ def _warn_pk_with_no_anticipated_value(c):
"keys if AUTO_INCREMENT/SERIAL/IDENTITY "
"behavior is expected for one of the columns in the primary key. "
"CREATE TABLE statements are impacted by this change as well on "
- "most backends.")
+ "most backends."
+ )
util.warn(msg)
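# A table shape that triggers the warning above (sketch): composite
# primary key, one column opts out of autoincrement and has no default,
# and the INSERT supplies no value for it:
from sqlalchemy import Column, Integer, MetaData, Table

m = MetaData()
pair = Table(
    "pair",
    m,
    Column("a", Integer, primary_key=True),
    Column("b", Integer, primary_key=True, autoincrement=False),
)
pair.insert().values(a=1).compile()  # warns about column 'pair.b'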
diff --git a/lib/sqlalchemy/sql/ddl.py b/lib/sqlalchemy/sql/ddl.py
index 91e93efe7..f21b3d7f0 100644
--- a/lib/sqlalchemy/sql/ddl.py
+++ b/lib/sqlalchemy/sql/ddl.py
@@ -56,8 +56,9 @@ class DDLElement(Executable, _DDLCompiles):
"""
- _execution_options = Executable.\
- _execution_options.union({'autocommit': True})
+ _execution_options = Executable._execution_options.union(
+ {"autocommit": True}
+ )
target = None
on = None
@@ -95,11 +96,13 @@ class DDLElement(Executable, _DDLCompiles):
if self._should_execute(target, bind):
return bind.execute(self.against(target))
else:
- bind.engine.logger.info(
- "DDL execution skipped, criteria not met.")
+ bind.engine.logger.info("DDL execution skipped, criteria not met.")
- @util.deprecated("0.7", "See :class:`.DDLEvents`, as well as "
- ":meth:`.DDLElement.execute_if`.")
+ @util.deprecated(
+ "0.7",
+ "See :class:`.DDLEvents`, as well as "
+ ":meth:`.DDLElement.execute_if`.",
+ )
def execute_at(self, event_name, target):
"""Link execution of this DDL to the DDL lifecycle of a SchemaItem.
@@ -129,11 +132,12 @@ class DDLElement(Executable, _DDLCompiles):
"""
def call_event(target, connection, **kw):
- if self._should_execute_deprecated(event_name,
- target, connection, **kw):
+ if self._should_execute_deprecated(
+ event_name, target, connection, **kw
+ ):
return connection.execute(self.against(target))
- event.listen(target, "" + event_name.replace('-', '_'), call_event)
+ event.listen(target, "" + event_name.replace("-", "_"), call_event)
@_generative
def against(self, target):
@@ -211,8 +215,9 @@ class DDLElement(Executable, _DDLCompiles):
self.state = state
def _should_execute(self, target, bind, **kw):
- if self.on is not None and \
- not self._should_execute_deprecated(None, target, bind, **kw):
+ if self.on is not None and not self._should_execute_deprecated(
+ None, target, bind, **kw
+ ):
return False
if isinstance(self.dialect, util.string_types):
@@ -221,9 +226,9 @@ class DDLElement(Executable, _DDLCompiles):
elif isinstance(self.dialect, (tuple, list, set)):
if bind.engine.name not in self.dialect:
return False
- if (self.callable_ is not None and
- not self.callable_(self, target, bind,
- state=self.state, **kw)):
+ if self.callable_ is not None and not self.callable_(
+ self, target, bind, state=self.state, **kw
+ ):
return False
return True
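# Usage sketch for these criteria via the public execute_if() API:
from sqlalchemy import DDL, MetaData, event

metadata = MetaData()
event.listen(
    metadata,
    "before_create",
    DDL("CREATE EXTENSION hstore").execute_if(dialect="postgresql"),
)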
@@ -245,13 +250,15 @@ class DDLElement(Executable, _DDLCompiles):
return bind.execute(self.against(target))
def _check_ddl_on(self, on):
- if (on is not None and
- (not isinstance(on, util.string_types + (tuple, list, set)) and
- not util.callable(on))):
+ if on is not None and (
+ not isinstance(on, util.string_types + (tuple, list, set))
+ and not util.callable(on)
+ ):
raise exc.ArgumentError(
"Expected the name of a database dialect, a tuple "
"of names, or a callable for "
- "'on' criteria, got type '%s'." % type(on).__name__)
+ "'on' criteria, got type '%s'." % type(on).__name__
+ )
def bind(self):
if self._bind:
@@ -259,6 +266,7 @@ class DDLElement(Executable, _DDLCompiles):
def _set_bind(self, bind):
self._bind = bind
+
bind = property(bind, _set_bind)
def _generate(self):
@@ -375,8 +383,9 @@ class DDL(DDLElement):
if not isinstance(statement, util.string_types):
raise exc.ArgumentError(
- "Expected a string or unicode SQL statement, got '%r'" %
- statement)
+ "Expected a string or unicode SQL statement, got '%r'"
+ % statement
+ )
self.statement = statement
self.context = context or {}
@@ -386,12 +395,18 @@ class DDL(DDLElement):
self._bind = bind
def __repr__(self):
- return '<%s@%s; %s>' % (
- type(self).__name__, id(self),
- ', '.join([repr(self.statement)] +
- ['%s=%r' % (key, getattr(self, key))
- for key in ('on', 'context')
- if getattr(self, key)]))
+ return "<%s@%s; %s>" % (
+ type(self).__name__,
+ id(self),
+ ", ".join(
+ [repr(self.statement)]
+ + [
+ "%s=%r" % (key, getattr(self, key))
+ for key in ("on", "context")
+ if getattr(self, key)
+ ]
+ ),
+ )
class _CreateDropBase(DDLElement):
@@ -464,8 +479,8 @@ class CreateTable(_CreateDropBase):
__visit_name__ = "create_table"
def __init__(
- self, element, on=None, bind=None,
- include_foreign_key_constraints=None):
+ self, element, on=None, bind=None, include_foreign_key_constraints=None
+ ):
"""Create a :class:`.CreateTable` construct.
:param element: a :class:`.Table` that's the subject
@@ -481,9 +496,7 @@ class CreateTable(_CreateDropBase):
"""
super(CreateTable, self).__init__(element, on=on, bind=bind)
- self.columns = [CreateColumn(column)
- for column in element.columns
- ]
+ self.columns = [CreateColumn(column) for column in element.columns]
self.include_foreign_key_constraints = include_foreign_key_constraints
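# Quick sketch of the construct in use:
from sqlalchemy import Column, Integer, MetaData, Table
from sqlalchemy.schema import CreateTable

m = MetaData()
t = Table("t", m, Column("x", Integer, primary_key=True))
print(CreateTable(t))
# CREATE TABLE t (x INTEGER NOT NULL, PRIMARY KEY (x))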
@@ -494,6 +507,7 @@ class _DropView(_CreateDropBase):
This object will eventually be part of a public "view" API.
"""
+
__visit_name__ = "drop_view"
@@ -602,7 +616,8 @@ class CreateColumn(_DDLCompiles):
to support custom column creation styles.
"""
- __visit_name__ = 'create_column'
+
+ __visit_name__ = "create_column"
def __init__(self, element):
self.element = element
@@ -646,7 +661,8 @@ class AddConstraint(_CreateDropBase):
def __init__(self, element, *args, **kw):
super(AddConstraint, self).__init__(element, *args, **kw)
element._create_rule = util.portable_instancemethod(
- self._create_rule_disable)
+ self._create_rule_disable
+ )
class DropConstraint(_CreateDropBase):
@@ -658,7 +674,8 @@ class DropConstraint(_CreateDropBase):
self.cascade = cascade
super(DropConstraint, self).__init__(element, **kw)
element._create_rule = util.portable_instancemethod(
- self._create_rule_disable)
+ self._create_rule_disable
+ )
class SetTableComment(_CreateDropBase):
@@ -691,9 +708,9 @@ class DDLBase(SchemaVisitor):
class SchemaGenerator(DDLBase):
-
- def __init__(self, dialect, connection, checkfirst=False,
- tables=None, **kwargs):
+ def __init__(
+ self, dialect, connection, checkfirst=False, tables=None, **kwargs
+ ):
super(SchemaGenerator, self).__init__(connection, **kwargs)
self.checkfirst = checkfirst
self.tables = tables
@@ -706,25 +723,22 @@ class SchemaGenerator(DDLBase):
effective_schema = self.connection.schema_for_object(table)
if effective_schema:
self.dialect.validate_identifier(effective_schema)
- return not self.checkfirst or \
- not self.dialect.has_table(self.connection,
- table.name, schema=effective_schema)
+ return not self.checkfirst or not self.dialect.has_table(
+ self.connection, table.name, schema=effective_schema
+ )
def _can_create_sequence(self, sequence):
effective_schema = self.connection.schema_for_object(sequence)
- return self.dialect.supports_sequences and \
- (
- (not self.dialect.sequences_optional or
- not sequence.optional) and
- (
- not self.checkfirst or
- not self.dialect.has_sequence(
- self.connection,
- sequence.name,
- schema=effective_schema)
+ return self.dialect.supports_sequences and (
+ (not self.dialect.sequences_optional or not sequence.optional)
+ and (
+ not self.checkfirst
+ or not self.dialect.has_sequence(
+ self.connection, sequence.name, schema=effective_schema
)
)
+ )
def visit_metadata(self, metadata):
if self.tables is not None:
@@ -733,18 +747,23 @@ class SchemaGenerator(DDLBase):
tables = list(metadata.tables.values())
collection = sort_tables_and_constraints(
- [t for t in tables if self._can_create_table(t)])
-
- seq_coll = [s for s in metadata._sequences.values()
- if s.column is None and self._can_create_sequence(s)]
+ [t for t in tables if self._can_create_table(t)]
+ )
- event_collection = [
- t for (t, fks) in collection if t is not None
+ seq_coll = [
+ s
+ for s in metadata._sequences.values()
+ if s.column is None and self._can_create_sequence(s)
]
- metadata.dispatch.before_create(metadata, self.connection,
- tables=event_collection,
- checkfirst=self.checkfirst,
- _ddl_runner=self)
+
+ event_collection = [t for (t, fks) in collection if t is not None]
+ metadata.dispatch.before_create(
+ metadata,
+ self.connection,
+ tables=event_collection,
+ checkfirst=self.checkfirst,
+ _ddl_runner=self,
+ )
for seq in seq_coll:
self.traverse_single(seq, create_ok=True)
@@ -752,30 +771,40 @@ class SchemaGenerator(DDLBase):
for table, fkcs in collection:
if table is not None:
self.traverse_single(
- table, create_ok=True,
+ table,
+ create_ok=True,
include_foreign_key_constraints=fkcs,
- _is_metadata_operation=True)
+ _is_metadata_operation=True,
+ )
else:
for fkc in fkcs:
self.traverse_single(fkc)
- metadata.dispatch.after_create(metadata, self.connection,
- tables=event_collection,
- checkfirst=self.checkfirst,
- _ddl_runner=self)
+ metadata.dispatch.after_create(
+ metadata,
+ self.connection,
+ tables=event_collection,
+ checkfirst=self.checkfirst,
+ _ddl_runner=self,
+ )
def visit_table(
- self, table, create_ok=False,
- include_foreign_key_constraints=None,
- _is_metadata_operation=False):
+ self,
+ table,
+ create_ok=False,
+ include_foreign_key_constraints=None,
+ _is_metadata_operation=False,
+ ):
if not create_ok and not self._can_create_table(table):
return
table.dispatch.before_create(
- table, self.connection,
+ table,
+ self.connection,
checkfirst=self.checkfirst,
_ddl_runner=self,
- _is_metadata_operation=_is_metadata_operation)
+ _is_metadata_operation=_is_metadata_operation,
+ )
for column in table.columns:
if column.default is not None:
@@ -788,10 +817,11 @@ class SchemaGenerator(DDLBase):
self.connection.execute(
CreateTable(
table,
- include_foreign_key_constraints=include_foreign_key_constraints
- ))
+ include_foreign_key_constraints=include_foreign_key_constraints,
+ )
+ )
- if hasattr(table, 'indexes'):
+ if hasattr(table, "indexes"):
for index in table.indexes:
self.traverse_single(index)
@@ -804,10 +834,12 @@ class SchemaGenerator(DDLBase):
self.connection.execute(SetColumnComment(column))
table.dispatch.after_create(
- table, self.connection,
+ table,
+ self.connection,
checkfirst=self.checkfirst,
_ddl_runner=self,
- _is_metadata_operation=_is_metadata_operation)
+ _is_metadata_operation=_is_metadata_operation,
+ )
def visit_foreign_key_constraint(self, constraint):
if not self.dialect.supports_alter:
@@ -824,9 +856,9 @@ class SchemaGenerator(DDLBase):
class SchemaDropper(DDLBase):
-
- def __init__(self, dialect, connection, checkfirst=False,
- tables=None, **kwargs):
+ def __init__(
+ self, dialect, connection, checkfirst=False, tables=None, **kwargs
+ ):
super(SchemaDropper, self).__init__(connection, **kwargs)
self.checkfirst = checkfirst
self.tables = tables
@@ -842,15 +874,17 @@ class SchemaDropper(DDLBase):
try:
unsorted_tables = [t for t in tables if self._can_drop_table(t)]
- collection = list(reversed(
- sort_tables_and_constraints(
- unsorted_tables,
- filter_fn=lambda constraint: False
- if not self.dialect.supports_alter
- or constraint.name is None
- else None
+ collection = list(
+ reversed(
+ sort_tables_and_constraints(
+ unsorted_tables,
+ filter_fn=lambda constraint: False
+ if not self.dialect.supports_alter
+ or constraint.name is None
+ else None,
+ )
)
- ))
+ )
except exc.CircularDependencyError as err2:
if not self.dialect.supports_alter:
util.warn(
@@ -862,16 +896,15 @@ class SchemaDropper(DDLBase):
"ForeignKeyConstraint "
"objects involved in the cycle to mark these as known "
"cycles that will be ignored."
- % (
- ", ".join(sorted([t.fullname for t in err2.cycles]))
- )
+ % (", ".join(sorted([t.fullname for t in err2.cycles])))
)
collection = [(t, ()) for t in unsorted_tables]
else:
util.raise_from_cause(
exc.CircularDependencyError(
err2.args[0],
- err2.cycles, err2.edges,
+ err2.cycles,
+ err2.edges,
msg="Can't sort tables for DROP; an "
"unresolvable foreign key "
"dependency exists between tables: %s. Please ensure "
@@ -880,9 +913,10 @@ class SchemaDropper(DDLBase):
"names so that they can be dropped using "
"DROP CONSTRAINT."
% (
- ", ".join(sorted([t.fullname for t in err2.cycles]))
- )
-
+ ", ".join(
+ sorted([t.fullname for t in err2.cycles])
+ )
+ ),
)
)
@@ -892,18 +926,21 @@ class SchemaDropper(DDLBase):
if s.column is None and self._can_drop_sequence(s)
]
- event_collection = [
- t for (t, fks) in collection if t is not None
- ]
+ event_collection = [t for (t, fks) in collection if t is not None]
metadata.dispatch.before_drop(
- metadata, self.connection, tables=event_collection,
- checkfirst=self.checkfirst, _ddl_runner=self)
+ metadata,
+ self.connection,
+ tables=event_collection,
+ checkfirst=self.checkfirst,
+ _ddl_runner=self,
+ )
for table, fkcs in collection:
if table is not None:
self.traverse_single(
- table, drop_ok=True, _is_metadata_operation=True)
+ table, drop_ok=True, _is_metadata_operation=True
+ )
else:
for fkc in fkcs:
self.traverse_single(fkc)
@@ -912,8 +949,12 @@ class SchemaDropper(DDLBase):
self.traverse_single(seq, drop_ok=True)
metadata.dispatch.after_drop(
- metadata, self.connection, tables=event_collection,
- checkfirst=self.checkfirst, _ddl_runner=self)
+ metadata,
+ self.connection,
+ tables=event_collection,
+ checkfirst=self.checkfirst,
+ _ddl_runner=self,
+ )
def _can_drop_table(self, table):
self.dialect.validate_identifier(table.name)
@@ -921,19 +962,20 @@ class SchemaDropper(DDLBase):
if effective_schema:
self.dialect.validate_identifier(effective_schema)
return not self.checkfirst or self.dialect.has_table(
- self.connection, table.name, schema=effective_schema)
+ self.connection, table.name, schema=effective_schema
+ )
def _can_drop_sequence(self, sequence):
effective_schema = self.connection.schema_for_object(sequence)
- return self.dialect.supports_sequences and \
- ((not self.dialect.sequences_optional or
- not sequence.optional) and
- (not self.checkfirst or
- self.dialect.has_sequence(
- self.connection,
- sequence.name,
- schema=effective_schema))
- )
+ return self.dialect.supports_sequences and (
+ (not self.dialect.sequences_optional or not sequence.optional)
+ and (
+ not self.checkfirst
+ or self.dialect.has_sequence(
+ self.connection, sequence.name, schema=effective_schema
+ )
+ )
+ )
def visit_index(self, index):
self.connection.execute(DropIndex(index))
@@ -943,10 +985,12 @@ class SchemaDropper(DDLBase):
return
table.dispatch.before_drop(
- table, self.connection,
+ table,
+ self.connection,
checkfirst=self.checkfirst,
_ddl_runner=self,
- _is_metadata_operation=_is_metadata_operation)
+ _is_metadata_operation=_is_metadata_operation,
+ )
self.connection.execute(DropTable(table))
@@ -960,10 +1004,12 @@ class SchemaDropper(DDLBase):
self.traverse_single(column.default)
table.dispatch.after_drop(
- table, self.connection,
+ table,
+ self.connection,
checkfirst=self.checkfirst,
_ddl_runner=self,
- _is_metadata_operation=_is_metadata_operation)
+ _is_metadata_operation=_is_metadata_operation,
+ )
def visit_foreign_key_constraint(self, constraint):
if not self.dialect.supports_alter:
@@ -1019,25 +1065,29 @@ def sort_tables(tables, skip_fn=None, extra_dependencies=None):
"""
if skip_fn is not None:
+
def _skip_fn(fkc):
for fk in fkc.elements:
if skip_fn(fk):
return True
else:
return None
+
else:
_skip_fn = None
return [
- t for (t, fkcs) in
- sort_tables_and_constraints(
- tables, filter_fn=_skip_fn, extra_dependencies=extra_dependencies)
+ t
+ for (t, fkcs) in sort_tables_and_constraints(
+ tables, filter_fn=_skip_fn, extra_dependencies=extra_dependencies
+ )
if t is not None
]
def sort_tables_and_constraints(
- tables, filter_fn=None, extra_dependencies=None):
+ tables, filter_fn=None, extra_dependencies=None
+):
"""sort a collection of :class:`.Table` / :class:`.ForeignKeyConstraint`
objects.
@@ -1109,8 +1159,9 @@ def sort_tables_and_constraints(
try:
candidate_sort = list(
topological.sort(
- fixed_dependencies.union(mutable_dependencies), tables,
- deterministic_order=True
+ fixed_dependencies.union(mutable_dependencies),
+ tables,
+ deterministic_order=True,
)
)
except exc.CircularDependencyError as err:
@@ -1118,8 +1169,10 @@ def sort_tables_and_constraints(
if edge in mutable_dependencies:
table = edge[1]
can_remove = [
- fkc for fkc in table.foreign_key_constraints
- if filter_fn is None or filter_fn(fkc) is not False]
+ fkc
+ for fkc in table.foreign_key_constraints
+ if filter_fn is None or filter_fn(fkc) is not False
+ ]
remaining_fkcs.update(can_remove)
for fkc in can_remove:
dependent_on = fkc.referred_table
@@ -1127,8 +1180,9 @@ def sort_tables_and_constraints(
mutable_dependencies.discard((dependent_on, table))
candidate_sort = list(
topological.sort(
- fixed_dependencies.union(mutable_dependencies), tables,
- deterministic_order=True
+ fixed_dependencies.union(mutable_dependencies),
+ tables,
+ deterministic_order=True,
)
)
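
For orientation (not part of the patch): a minimal sketch of the two sort helpers whose formatting changes above, assuming a toy parent/child schema; the table names are illustrative.

    from sqlalchemy import Column, ForeignKey, Integer, MetaData, Table
    from sqlalchemy.sql.ddl import sort_tables, sort_tables_and_constraints

    metadata = MetaData()
    parent = Table("parent", metadata, Column("id", Integer, primary_key=True))
    child = Table(
        "child",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("parent_id", Integer, ForeignKey("parent.id")),
    )

    # sort_tables() returns Table objects in dependency (CREATE) order.
    print([t.name for t in sort_tables(list(metadata.tables.values()))])
    # ['parent', 'child']

    # sort_tables_and_constraints() yields (Table, [ForeignKeyConstraint, ...])
    # pairs, plus a trailing (None, constraints) entry for constraints that
    # must be handled separately; SchemaDropper walks this sequence in reverse.
    for table, fkcs in sort_tables_and_constraints(
        list(metadata.tables.values())
    ):
        print(table, fkcs)
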
diff --git a/lib/sqlalchemy/sql/default_comparator.py b/lib/sqlalchemy/sql/default_comparator.py
index 8149f9731..fa0052198 100644
--- a/lib/sqlalchemy/sql/default_comparator.py
+++ b/lib/sqlalchemy/sql/default_comparator.py
@@ -11,19 +11,43 @@
from .. import exc, util
from . import type_api
from . import operators
-from .elements import BindParameter, True_, False_, BinaryExpression, \
- Null, _const_expr, _clause_element_as_expr, \
- ClauseList, ColumnElement, TextClause, UnaryExpression, \
- collate, _is_literal, _literal_as_text, ClauseElement, and_, or_, \
- Slice, Visitable, _literal_as_binds, CollectionAggregate, \
- Tuple
+from .elements import (
+ BindParameter,
+ True_,
+ False_,
+ BinaryExpression,
+ Null,
+ _const_expr,
+ _clause_element_as_expr,
+ ClauseList,
+ ColumnElement,
+ TextClause,
+ UnaryExpression,
+ collate,
+ _is_literal,
+ _literal_as_text,
+ ClauseElement,
+ and_,
+ or_,
+ Slice,
+ Visitable,
+ _literal_as_binds,
+ CollectionAggregate,
+ Tuple,
+)
from .selectable import SelectBase, Alias, Selectable, ScalarSelect
-def _boolean_compare(expr, op, obj, negate=None, reverse=False,
- _python_is_types=(util.NoneType, bool),
- result_type = None,
- **kwargs):
+def _boolean_compare(
+ expr,
+ op,
+ obj,
+ negate=None,
+ reverse=False,
+ _python_is_types=(util.NoneType, bool),
+ result_type=None,
+ **kwargs
+):
if result_type is None:
result_type = type_api.BOOLEANTYPE
@@ -33,57 +57,64 @@ def _boolean_compare(expr, op, obj, negate=None, reverse=False,
# allow x ==/!= True/False to be treated as a literal.
# this comes out to "== / != true/false" or "1/0" if those
# constants aren't supported and works on all platforms
- if op in (operators.eq, operators.ne) and \
- isinstance(obj, (bool, True_, False_)):
- return BinaryExpression(expr,
- _literal_as_text(obj),
- op,
- type_=result_type,
- negate=negate, modifiers=kwargs)
+ if op in (operators.eq, operators.ne) and isinstance(
+ obj, (bool, True_, False_)
+ ):
+ return BinaryExpression(
+ expr,
+ _literal_as_text(obj),
+ op,
+ type_=result_type,
+ negate=negate,
+ modifiers=kwargs,
+ )
elif op in (operators.is_distinct_from, operators.isnot_distinct_from):
- return BinaryExpression(expr,
- _literal_as_text(obj),
- op,
- type_=result_type,
- negate=negate, modifiers=kwargs)
+ return BinaryExpression(
+ expr,
+ _literal_as_text(obj),
+ op,
+ type_=result_type,
+ negate=negate,
+ modifiers=kwargs,
+ )
else:
# all other None/True/False uses IS, IS NOT
if op in (operators.eq, operators.is_):
- return BinaryExpression(expr, _const_expr(obj),
- operators.is_,
- negate=operators.isnot,
- type_=result_type
- )
+ return BinaryExpression(
+ expr,
+ _const_expr(obj),
+ operators.is_,
+ negate=operators.isnot,
+ type_=result_type,
+ )
elif op in (operators.ne, operators.isnot):
- return BinaryExpression(expr, _const_expr(obj),
- operators.isnot,
- negate=operators.is_,
- type_=result_type
- )
+ return BinaryExpression(
+ expr,
+ _const_expr(obj),
+ operators.isnot,
+ negate=operators.is_,
+ type_=result_type,
+ )
else:
raise exc.ArgumentError(
"Only '=', '!=', 'is_()', 'isnot()', "
"'is_distinct_from()', 'isnot_distinct_from()' "
- "operators can be used with None/True/False")
+ "operators can be used with None/True/False"
+ )
else:
obj = _check_literal(expr, op, obj)
if reverse:
- return BinaryExpression(obj,
- expr,
- op,
- type_=result_type,
- negate=negate, modifiers=kwargs)
+ return BinaryExpression(
+ obj, expr, op, type_=result_type, negate=negate, modifiers=kwargs
+ )
else:
- return BinaryExpression(expr,
- obj,
- op,
- type_=result_type,
- negate=negate, modifiers=kwargs)
+ return BinaryExpression(
+ expr, obj, op, type_=result_type, negate=negate, modifiers=kwargs
+ )
-def _custom_op_operate(expr, op, obj, reverse=False, result_type=None,
- **kw):
+def _custom_op_operate(expr, op, obj, reverse=False, result_type=None, **kw):
if result_type is None:
if op.return_type:
result_type = op.return_type
@@ -91,11 +122,11 @@ def _custom_op_operate(expr, op, obj, reverse=False, result_type=None,
result_type = type_api.BOOLEANTYPE
return _binary_operate(
- expr, op, obj, reverse=reverse, result_type=result_type, **kw)
+ expr, op, obj, reverse=reverse, result_type=result_type, **kw
+ )
-def _binary_operate(expr, op, obj, reverse=False, result_type=None,
- **kw):
+def _binary_operate(expr, op, obj, reverse=False, result_type=None, **kw):
obj = _check_literal(expr, op, obj)
if reverse:
@@ -105,10 +136,10 @@ def _binary_operate(expr, op, obj, reverse=False, result_type=None,
if result_type is None:
op, result_type = left.comparator._adapt_expression(
- op, right.comparator)
+ op, right.comparator
+ )
- return BinaryExpression(
- left, right, op, type_=result_type, modifiers=kw)
+ return BinaryExpression(left, right, op, type_=result_type, modifiers=kw)
def _conjunction_operate(expr, op, other, **kw):
@@ -128,8 +159,7 @@ def _in_impl(expr, op, seq_or_selectable, negate_op, **kw):
seq_or_selectable = _clause_element_as_expr(seq_or_selectable)
if isinstance(seq_or_selectable, ScalarSelect):
- return _boolean_compare(expr, op, seq_or_selectable,
- negate=negate_op)
+ return _boolean_compare(expr, op, seq_or_selectable, negate=negate_op)
elif isinstance(seq_or_selectable, SelectBase):
# TODO: if we ever want to support (x, y, z) IN (select x,
@@ -138,32 +168,33 @@ def _in_impl(expr, op, seq_or_selectable, negate_op, **kw):
# does not export itself as a FROM clause
return _boolean_compare(
- expr, op, seq_or_selectable.as_scalar(),
- negate=negate_op, **kw)
+ expr, op, seq_or_selectable.as_scalar(), negate=negate_op, **kw
+ )
elif isinstance(seq_or_selectable, (Selectable, TextClause)):
- return _boolean_compare(expr, op, seq_or_selectable,
- negate=negate_op, **kw)
+ return _boolean_compare(
+ expr, op, seq_or_selectable, negate=negate_op, **kw
+ )
elif isinstance(seq_or_selectable, ClauseElement):
- if isinstance(seq_or_selectable, BindParameter) and \
- seq_or_selectable.expanding:
+ if (
+ isinstance(seq_or_selectable, BindParameter)
+ and seq_or_selectable.expanding
+ ):
if isinstance(expr, Tuple):
- seq_or_selectable = (
- seq_or_selectable._with_expanding_in_types(
- [elem.type for elem in expr]
- )
+ seq_or_selectable = seq_or_selectable._with_expanding_in_types(
+ [elem.type for elem in expr]
)
return _boolean_compare(
- expr, op,
- seq_or_selectable,
- negate=negate_op)
+ expr, op, seq_or_selectable, negate=negate_op
+ )
else:
raise exc.InvalidRequestError(
- 'in_() accepts'
- ' either a list of expressions, '
+ "in_() accepts"
+ " either a list of expressions, "
'a selectable, or an "expanding" bound parameter: %r'
- % seq_or_selectable)
+ % seq_or_selectable
+ )
# Handle non selectable arguments as sequences
args = []
@@ -171,9 +202,10 @@ def _in_impl(expr, op, seq_or_selectable, negate_op, **kw):
if not _is_literal(o):
if not isinstance(o, operators.ColumnOperators):
raise exc.InvalidRequestError(
- 'in_() accepts'
- ' either a list of expressions, '
- 'a selectable, or an "expanding" bound parameter: %r' % o)
+ "in_() accepts"
+ " either a list of expressions, "
+ 'a selectable, or an "expanding" bound parameter: %r' % o
+ )
elif o is None:
o = Null()
else:
@@ -182,15 +214,14 @@ def _in_impl(expr, op, seq_or_selectable, negate_op, **kw):
if len(args) == 0:
op, negate_op = (
- operators.empty_in_op,
- operators.empty_notin_op) if op is operators.in_op \
- else (
- operators.empty_notin_op,
- operators.empty_in_op)
+ (operators.empty_in_op, operators.empty_notin_op)
+ if op is operators.in_op
+ else (operators.empty_notin_op, operators.empty_in_op)
+ )
- return _boolean_compare(expr, op,
- ClauseList(*args).self_group(against=op),
- negate=negate_op)
+ return _boolean_compare(
+ expr, op, ClauseList(*args).self_group(against=op), negate=negate_op
+ )
def _getitem_impl(expr, op, other, **kw):
@@ -202,13 +233,14 @@ def _getitem_impl(expr, op, other, **kw):
def _unsupported_impl(expr, op, *arg, **kw):
- raise NotImplementedError("Operator '%s' is not supported on "
- "this expression" % op.__name__)
+ raise NotImplementedError(
+ "Operator '%s' is not supported on " "this expression" % op.__name__
+ )
def _inv_impl(expr, op, **kw):
"""See :meth:`.ColumnOperators.__inv__`."""
- if hasattr(expr, 'negation_clause'):
+ if hasattr(expr, "negation_clause"):
return expr.negation_clause
else:
return expr._negate()
@@ -223,20 +255,22 @@ def _match_impl(expr, op, other, **kw):
"""See :meth:`.ColumnOperators.match`."""
return _boolean_compare(
- expr, operators.match_op,
- _check_literal(
- expr, operators.match_op, other),
+ expr,
+ operators.match_op,
+ _check_literal(expr, operators.match_op, other),
result_type=type_api.MATCHTYPE,
negate=operators.notmatch_op
- if op is operators.match_op else operators.match_op,
+ if op is operators.match_op
+ else operators.match_op,
**kw
)
def _distinct_impl(expr, op, **kw):
"""See :meth:`.ColumnOperators.distinct`."""
- return UnaryExpression(expr, operator=operators.distinct_op,
- type_=expr.type)
+ return UnaryExpression(
+ expr, operator=operators.distinct_op, type_=expr.type
+ )
def _between_impl(expr, op, cleft, cright, **kw):
@@ -247,17 +281,21 @@ def _between_impl(expr, op, cleft, cright, **kw):
_check_literal(expr, operators.and_, cleft),
_check_literal(expr, operators.and_, cright),
operator=operators.and_,
- group=False, group_contents=False),
+ group=False,
+ group_contents=False,
+ ),
op,
negate=operators.notbetween_op
if op is operators.between_op
else operators.between_op,
- modifiers=kw)
+ modifiers=kw,
+ )
def _collate_impl(expr, op, other, **kw):
return collate(expr, other)
+
# a mapping of operators with the method they use, along with
# their negated operator for comparison operators
operator_lookup = {
@@ -271,8 +309,8 @@ operator_lookup = {
"mod": (_binary_operate,),
"truediv": (_binary_operate,),
"custom_op": (_custom_op_operate,),
- "json_path_getitem_op": (_binary_operate, ),
- "json_getitem_op": (_binary_operate, ),
+ "json_path_getitem_op": (_binary_operate,),
+ "json_getitem_op": (_binary_operate,),
"concat_op": (_binary_operate,),
"any_op": (_scalar, CollectionAggregate._create_any),
"all_op": (_scalar, CollectionAggregate._create_all),
@@ -303,8 +341,8 @@ operator_lookup = {
"match_op": (_match_impl,),
"notmatch_op": (_match_impl,),
"distinct_op": (_distinct_impl,),
- "between_op": (_between_impl, ),
- "notbetween_op": (_between_impl, ),
+ "between_op": (_between_impl,),
+ "notbetween_op": (_between_impl,),
"neg": (_neg_impl,),
"getitem": (_getitem_impl,),
"lshift": (_unsupported_impl,),
@@ -315,12 +353,11 @@ operator_lookup = {
def _check_literal(expr, operator, other, bindparam_type=None):
if isinstance(other, (ColumnElement, TextClause)):
- if isinstance(other, BindParameter) and \
- other.type._isnull:
+ if isinstance(other, BindParameter) and other.type._isnull:
other = other._clone()
other.type = expr.type
return other
- elif hasattr(other, '__clause_element__'):
+ elif hasattr(other, "__clause_element__"):
other = other.__clause_element__()
elif isinstance(other, type_api.TypeEngine.Comparator):
other = other.expr
@@ -331,4 +368,3 @@ def _check_literal(expr, operator, other, bindparam_type=None):
return expr._bind_param(operator, other, type_=bindparam_type)
else:
return other
-
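
A sketch (illustrative, not from the patch) of the dispatch this module implements, shown against the default string compiler:

    from sqlalchemy import column

    x = column("x")

    # _boolean_compare() routes comparisons against None to IS / IS NOT
    # (the == / != against None here is deliberate):
    print(x == None)  # x IS NULL
    print(x != None)  # x IS NOT NULL

    # _in_impl() turns a plain sequence into a grouped list of binds:
    print(x.in_([1, 2, 3]))  # x IN (:x_1, :x_2, :x_3)
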
diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py
index d6890de15..0cea5ccc4 100644
--- a/lib/sqlalchemy/sql/dml.py
+++ b/lib/sqlalchemy/sql/dml.py
@@ -9,26 +9,43 @@ Provide :class:`.Insert`, :class:`.Update` and :class:`.Delete`.
"""
-from .base import Executable, _generative, _from_objects, DialectKWArgs, \
- ColumnCollection
-from .elements import ClauseElement, _literal_as_text, Null, and_, _clone, \
- _column_as_key
-from .selectable import _interpret_as_from, _interpret_as_select, \
- HasPrefixes, HasCTE
+from .base import (
+ Executable,
+ _generative,
+ _from_objects,
+ DialectKWArgs,
+ ColumnCollection,
+)
+from .elements import (
+ ClauseElement,
+ _literal_as_text,
+ Null,
+ and_,
+ _clone,
+ _column_as_key,
+)
+from .selectable import (
+ _interpret_as_from,
+ _interpret_as_select,
+ HasPrefixes,
+ HasCTE,
+)
from .. import util
from .. import exc
class UpdateBase(
- HasCTE, DialectKWArgs, HasPrefixes, Executable, ClauseElement):
+ HasCTE, DialectKWArgs, HasPrefixes, Executable, ClauseElement
+):
"""Form the base for ``INSERT``, ``UPDATE``, and ``DELETE`` statements.
"""
- __visit_name__ = 'update_base'
+ __visit_name__ = "update_base"
- _execution_options = \
- Executable._execution_options.union({'autocommit': True})
+ _execution_options = Executable._execution_options.union(
+ {"autocommit": True}
+ )
_hints = util.immutabledict()
_parameter_ordering = None
_prefixes = ()
@@ -37,30 +54,33 @@ class UpdateBase(
def _process_colparams(self, parameters):
def process_single(p):
if isinstance(p, (list, tuple)):
- return dict(
- (c.key, pval)
- for c, pval in zip(self.table.c, p)
- )
+ return dict((c.key, pval) for c, pval in zip(self.table.c, p))
else:
return p
if self._preserve_parameter_order and parameters is not None:
- if not isinstance(parameters, list) or \
- (parameters and not isinstance(parameters[0], tuple)):
+ if not isinstance(parameters, list) or (
+ parameters and not isinstance(parameters[0], tuple)
+ ):
raise ValueError(
"When preserve_parameter_order is True, "
- "values() only accepts a list of 2-tuples")
+ "values() only accepts a list of 2-tuples"
+ )
self._parameter_ordering = [key for key, value in parameters]
return dict(parameters), False
- if (isinstance(parameters, (list, tuple)) and parameters and
- isinstance(parameters[0], (list, tuple, dict))):
+ if (
+ isinstance(parameters, (list, tuple))
+ and parameters
+ and isinstance(parameters[0], (list, tuple, dict))
+ ):
if not self._supports_multi_parameters:
raise exc.InvalidRequestError(
"This construct does not support "
- "multiple parameter sets.")
+ "multiple parameter sets."
+ )
return [process_single(p) for p in parameters], True
else:
@@ -77,7 +97,8 @@ class UpdateBase(
raise NotImplementedError(
"params() is not supported for INSERT/UPDATE/DELETE statements."
" To set the values for an INSERT or UPDATE statement, use"
- " stmt.values(**parameters).")
+ " stmt.values(**parameters)."
+ )
def bind(self):
"""Return a 'bind' linked to this :class:`.UpdateBase`
@@ -88,6 +109,7 @@ class UpdateBase(
def _set_bind(self, bind):
self._bind = bind
+
bind = property(bind, _set_bind)
@_generative
@@ -181,15 +203,14 @@ class UpdateBase(
if selectable is None:
selectable = self.table
- self._hints = self._hints.union(
- {(selectable, dialect_name): text})
+ self._hints = self._hints.union({(selectable, dialect_name): text})
class ValuesBase(UpdateBase):
"""Supplies support for :meth:`.ValuesBase.values` to
INSERT and UPDATE constructs."""
- __visit_name__ = 'values_base'
+ __visit_name__ = "values_base"
_supports_multi_parameters = False
_has_multi_parameters = False
@@ -199,8 +220,9 @@ class ValuesBase(UpdateBase):
def __init__(self, table, values, prefixes):
self.table = _interpret_as_from(table)
- self.parameters, self._has_multi_parameters = \
- self._process_colparams(values)
+ self.parameters, self._has_multi_parameters = self._process_colparams(
+ values
+ )
if prefixes:
self._setup_prefixes(prefixes)
@@ -332,23 +354,27 @@ class ValuesBase(UpdateBase):
"""
if self.select is not None:
raise exc.InvalidRequestError(
- "This construct already inserts from a SELECT")
+ "This construct already inserts from a SELECT"
+ )
if self._has_multi_parameters and kwargs:
raise exc.InvalidRequestError(
- "This construct already has multiple parameter sets.")
+ "This construct already has multiple parameter sets."
+ )
if args:
if len(args) > 1:
raise exc.ArgumentError(
"Only a single dictionary/tuple or list of "
- "dictionaries/tuples is accepted positionally.")
+ "dictionaries/tuples is accepted positionally."
+ )
v = args[0]
else:
v = {}
if self.parameters is None:
- self.parameters, self._has_multi_parameters = \
- self._process_colparams(v)
+ self.parameters, self._has_multi_parameters = self._process_colparams(
+ v
+ )
else:
if self._has_multi_parameters:
self.parameters = list(self.parameters)
@@ -356,7 +382,8 @@ class ValuesBase(UpdateBase):
if not self._has_multi_parameters:
raise exc.ArgumentError(
"Can't mix single-values and multiple values "
- "formats in one statement")
+ "formats in one statement"
+ )
self.parameters.extend(p)
else:
@@ -365,14 +392,16 @@ class ValuesBase(UpdateBase):
if self._has_multi_parameters:
raise exc.ArgumentError(
"Can't mix single-values and multiple values "
- "formats in one statement")
+ "formats in one statement"
+ )
self.parameters.update(p)
if kwargs:
if self._has_multi_parameters:
raise exc.ArgumentError(
"Can't pass kwargs and multiple parameter sets "
- "simultaneously")
+ "simultaneously"
+ )
else:
self.parameters.update(kwargs)
@@ -456,19 +485,22 @@ class Insert(ValuesBase):
:ref:`coretutorial_insert_expressions`
"""
- __visit_name__ = 'insert'
+
+ __visit_name__ = "insert"
_supports_multi_parameters = True
- def __init__(self,
- table,
- values=None,
- inline=False,
- bind=None,
- prefixes=None,
- returning=None,
- return_defaults=False,
- **dialect_kw):
+ def __init__(
+ self,
+ table,
+ values=None,
+ inline=False,
+ bind=None,
+ prefixes=None,
+ returning=None,
+ return_defaults=False,
+ **dialect_kw
+ ):
"""Construct an :class:`.Insert` object.
Similar functionality is available via the
@@ -526,7 +558,7 @@ class Insert(ValuesBase):
def get_children(self, **kwargs):
if self.select is not None:
- return self.select,
+ return (self.select,)
else:
return ()
@@ -578,11 +610,12 @@ class Insert(ValuesBase):
"""
if self.parameters:
raise exc.InvalidRequestError(
- "This construct already inserts value expressions")
+ "This construct already inserts value expressions"
+ )
- self.parameters, self._has_multi_parameters = \
- self._process_colparams(
- {_column_as_key(n): Null() for n in names})
+ self.parameters, self._has_multi_parameters = self._process_colparams(
+ {_column_as_key(n): Null() for n in names}
+ )
self.select_names = names
self.inline = True
@@ -603,19 +636,22 @@ class Update(ValuesBase):
function.
"""
- __visit_name__ = 'update'
-
- def __init__(self,
- table,
- whereclause=None,
- values=None,
- inline=False,
- bind=None,
- prefixes=None,
- returning=None,
- return_defaults=False,
- preserve_parameter_order=False,
- **dialect_kw):
+
+ __visit_name__ = "update"
+
+ def __init__(
+ self,
+ table,
+ whereclause=None,
+ values=None,
+ inline=False,
+ bind=None,
+ prefixes=None,
+ returning=None,
+ return_defaults=False,
+ preserve_parameter_order=False,
+ **dialect_kw
+ ):
r"""Construct an :class:`.Update` object.
E.g.::
@@ -745,7 +781,7 @@ class Update(ValuesBase):
def get_children(self, **kwargs):
if self._whereclause is not None:
- return self._whereclause,
+ return (self._whereclause,)
else:
return ()
@@ -761,8 +797,9 @@ class Update(ValuesBase):
"""
if self._whereclause is not None:
- self._whereclause = and_(self._whereclause,
- _literal_as_text(whereclause))
+ self._whereclause = and_(
+ self._whereclause, _literal_as_text(whereclause)
+ )
else:
self._whereclause = _literal_as_text(whereclause)
@@ -788,15 +825,17 @@ class Delete(UpdateBase):
"""
- __visit_name__ = 'delete'
-
- def __init__(self,
- table,
- whereclause=None,
- bind=None,
- returning=None,
- prefixes=None,
- **dialect_kw):
+ __visit_name__ = "delete"
+
+ def __init__(
+ self,
+ table,
+ whereclause=None,
+ bind=None,
+ returning=None,
+ prefixes=None,
+ **dialect_kw
+ ):
"""Construct :class:`.Delete` object.
Similar functionality is available via the
@@ -847,7 +886,7 @@ class Delete(UpdateBase):
def get_children(self, **kwargs):
if self._whereclause is not None:
- return self._whereclause,
+ return (self._whereclause,)
else:
return ()
@@ -856,8 +895,9 @@ class Delete(UpdateBase):
"""Add the given WHERE clause to a newly returned delete construct."""
if self._whereclause is not None:
- self._whereclause = and_(self._whereclause,
- _literal_as_text(whereclause))
+ self._whereclause = and_(
+ self._whereclause, _literal_as_text(whereclause)
+ )
else:
self._whereclause = _literal_as_text(whereclause)
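
For reference (not part of the patch), the three constructs above in use; the table and values are illustrative.

    from sqlalchemy import Column, Integer, MetaData, String, Table
    from sqlalchemy.sql import delete, insert, update

    metadata = MetaData()
    users = Table(
        "users",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("name", String(50)),
    )

    print(insert(users).values(name="alice"))
    # INSERT INTO users (name) VALUES (:name)

    print(update(users).where(users.c.id == 1).values(name="bob"))
    # UPDATE users SET name=:name WHERE users.id = :id_1

    print(delete(users).where(users.c.id == 1))
    # DELETE FROM users WHERE users.id = :id_1
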
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py
index de3b7992a..e857f2da8 100644
--- a/lib/sqlalchemy/sql/elements.py
+++ b/lib/sqlalchemy/sql/elements.py
@@ -51,9 +51,8 @@ def collate(expression, collation):
expr = _literal_as_binds(expression)
return BinaryExpression(
- expr,
- CollationClause(collation),
- operators.collate, type_=expr.type)
+ expr, CollationClause(collation), operators.collate, type_=expr.type
+ )
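
Sketch (not part of the patch): collate() above and between() below both reduce to a BinaryExpression; the collation name is illustrative.

    from sqlalchemy import between, collate, column

    print(collate(column("name"), "utf8_bin"))  # name COLLATE utf8_bin
    print(between(column("x"), 1, 10))          # x BETWEEN :x_1 AND :x_2
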
def between(expr, lower_bound, upper_bound, symmetric=False):
@@ -130,8 +129,6 @@ def literal(value, type_=None):
return BindParameter(None, value, type_=type_, unique=True)
-
-
def outparam(key, type_=None):
"""Create an 'OUT' parameter for usage in functions (stored procedures),
for databases which support them.
@@ -142,8 +139,7 @@ def outparam(key, type_=None):
attribute, which returns a dictionary containing the values.
"""
- return BindParameter(
- key, None, type_=type_, unique=False, isoutparam=True)
+ return BindParameter(key, None, type_=type_, unique=False, isoutparam=True)
def not_(clause):
@@ -163,7 +159,8 @@ class ClauseElement(Visitable):
expression.
"""
- __visit_name__ = 'clause'
+
+ __visit_name__ = "clause"
_annotations = {}
supports_execution = False
@@ -230,7 +227,7 @@ class ClauseElement(Visitable):
def __getstate__(self):
d = self.__dict__.copy()
- d.pop('_is_clone_of', None)
+ d.pop("_is_clone_of", None)
return d
def _annotate(self, values):
@@ -300,7 +297,8 @@ class ClauseElement(Visitable):
kwargs.update(optionaldict[0])
elif len(optionaldict) > 1:
raise exc.ArgumentError(
- "params() takes zero or one positional dictionary argument")
+ "params() takes zero or one positional dictionary argument"
+ )
def visit_bindparam(bind):
if bind.key in kwargs:
@@ -308,7 +306,8 @@ class ClauseElement(Visitable):
bind.required = False
if unique:
bind._convert_to_unique()
- return cloned_traverse(self, {}, {'bindparam': visit_bindparam})
+
+ return cloned_traverse(self, {}, {"bindparam": visit_bindparam})
def compare(self, other, **kw):
r"""Compare this ClauseElement to the given ClauseElement.
@@ -451,7 +450,7 @@ class ClauseElement(Visitable):
if util.py3k:
return str(self.compile())
else:
- return unicode(self.compile()).encode('ascii', 'backslashreplace')
+ return unicode(self.compile()).encode("ascii", "backslashreplace")
def __and__(self, other):
"""'and' at the ClauseElement level.
@@ -472,7 +471,7 @@ class ClauseElement(Visitable):
return or_(self, other)
def __invert__(self):
- if hasattr(self, 'negation_clause'):
+ if hasattr(self, "negation_clause"):
return self.negation_clause
else:
return self._negate()
@@ -481,7 +480,8 @@ class ClauseElement(Visitable):
return UnaryExpression(
self.self_group(against=operators.inv),
operator=operators.inv,
- negate=None)
+ negate=None,
+ )
def __bool__(self):
raise TypeError("Boolean value of this clause is not defined")
@@ -493,8 +493,12 @@ class ClauseElement(Visitable):
if friendly is None:
return object.__repr__(self)
else:
- return '<%s.%s at 0x%x; %s>' % (
- self.__module__, self.__class__.__name__, id(self), friendly)
+ return "<%s.%s at 0x%x; %s>" % (
+ self.__module__,
+ self.__class__.__name__,
+ id(self),
+ friendly,
+ )
class ColumnElement(operators.ColumnOperators, ClauseElement):
@@ -571,7 +575,7 @@ class ColumnElement(operators.ColumnOperators, ClauseElement):
"""
- __visit_name__ = 'column_element'
+ __visit_name__ = "column_element"
primary_key = False
foreign_keys = []
@@ -646,11 +650,12 @@ class ColumnElement(operators.ColumnOperators, ClauseElement):
_alt_names = ()
def self_group(self, against=None):
- if (against in (operators.and_, operators.or_, operators._asbool) and
- self.type._type_affinity
- is type_api.BOOLEANTYPE._type_affinity):
+ if (
+ against in (operators.and_, operators.or_, operators._asbool)
+ and self.type._type_affinity is type_api.BOOLEANTYPE._type_affinity
+ ):
return AsBoolean(self, operators.istrue, operators.isfalse)
- elif (against in (operators.any_op, operators.all_op)):
+ elif against in (operators.any_op, operators.all_op):
return Grouping(self)
else:
return self
@@ -675,7 +680,8 @@ class ColumnElement(operators.ColumnOperators, ClauseElement):
except AttributeError:
raise TypeError(
"Object %r associated with '.type' attribute "
- "is not a TypeEngine class or object" % self.type)
+ "is not a TypeEngine class or object" % self.type
+ )
else:
return comparator_factory(self)
@@ -684,10 +690,8 @@ class ColumnElement(operators.ColumnOperators, ClauseElement):
return getattr(self.comparator, key)
except AttributeError:
raise AttributeError(
- 'Neither %r object nor %r object has an attribute %r' % (
- type(self).__name__,
- type(self.comparator).__name__,
- key)
+ "Neither %r object nor %r object has an attribute %r"
+ % (type(self).__name__, type(self.comparator).__name__, key)
)
def operate(self, op, *other, **kwargs):
@@ -697,10 +701,14 @@ class ColumnElement(operators.ColumnOperators, ClauseElement):
return op(other, self.comparator, **kwargs)
def _bind_param(self, operator, obj, type_=None):
- return BindParameter(None, obj,
- _compared_to_operator=operator,
- type_=type_,
- _compared_to_type=self.type, unique=True)
+ return BindParameter(
+ None,
+ obj,
+ _compared_to_operator=operator,
+ type_=type_,
+ _compared_to_type=self.type,
+ unique=True,
+ )
@property
def expression(self):
@@ -713,17 +721,18 @@ class ColumnElement(operators.ColumnOperators, ClauseElement):
@property
def _select_iterable(self):
- return (self, )
+ return (self,)
@util.memoized_property
def base_columns(self):
- return util.column_set(c for c in self.proxy_set
- if not hasattr(c, '_proxies'))
+ return util.column_set(
+ c for c in self.proxy_set if not hasattr(c, "_proxies")
+ )
@util.memoized_property
def proxy_set(self):
s = util.column_set([self])
- if hasattr(self, '_proxies'):
+ if hasattr(self, "_proxies"):
for c in self._proxies:
s.update(c.proxy_set)
return s
@@ -738,11 +747,15 @@ class ColumnElement(operators.ColumnOperators, ClauseElement):
"""Return True if the given column element compares to this one
when targeting within a result row."""
- return hasattr(other, 'name') and hasattr(self, 'name') and \
- other.name == self.name
+ return (
+ hasattr(other, "name")
+ and hasattr(self, "name")
+ and other.name == self.name
+ )
def _make_proxy(
- self, selectable, name=None, name_is_truncatable=False, **kw):
+ self, selectable, name=None, name_is_truncatable=False, **kw
+ ):
"""Create a new :class:`.ColumnElement` representing this
:class:`.ColumnElement` as it appears in the select list of a
descending selectable.
@@ -762,13 +775,12 @@ class ColumnElement(operators.ColumnOperators, ClauseElement):
key = name
co = ColumnClause(
_as_truncated(name) if name_is_truncatable else name,
- type_=getattr(self, 'type', None),
- _selectable=selectable
+ type_=getattr(self, "type", None),
+ _selectable=selectable,
)
co._proxies = [self]
if selectable._is_clone_of is not None:
- co._is_clone_of = \
- selectable._is_clone_of.columns.get(key)
+ co._is_clone_of = selectable._is_clone_of.columns.get(key)
selectable._columns[key] = co
return co
@@ -788,7 +800,7 @@ class ColumnElement(operators.ColumnOperators, ClauseElement):
this one via foreign key or other criterion.
"""
- to_compare = (other, )
+ to_compare = (other,)
if equivalents and other in equivalents:
to_compare = equivalents[other].union(to_compare)
@@ -838,7 +850,7 @@ class ColumnElement(operators.ColumnOperators, ClauseElement):
self = self._is_clone_of
return _anonymous_label(
- '%%(%d %s)s' % (id(self), getattr(self, 'name', 'anon'))
+ "%%(%d %s)s" % (id(self), getattr(self, "name", "anon"))
)
@@ -862,18 +874,25 @@ class BindParameter(ColumnElement):
"""
- __visit_name__ = 'bindparam'
+ __visit_name__ = "bindparam"
_is_crud = False
_expanding_in_types = ()
- def __init__(self, key, value=NO_ARG, type_=None,
- unique=False, required=NO_ARG,
- quote=None, callable_=None,
- expanding=False,
- isoutparam=False,
- _compared_to_operator=None,
- _compared_to_type=None):
+ def __init__(
+ self,
+ key,
+ value=NO_ARG,
+ type_=None,
+ unique=False,
+ required=NO_ARG,
+ quote=None,
+ callable_=None,
+ expanding=False,
+ isoutparam=False,
+ _compared_to_operator=None,
+ _compared_to_type=None,
+ ):
r"""Produce a "bound expression".
The return value is an instance of :class:`.BindParameter`; this
@@ -1093,7 +1112,7 @@ class BindParameter(ColumnElement):
type_ = key.type
key = key.key
if required is NO_ARG:
- required = (value is NO_ARG and callable_ is None)
+ required = value is NO_ARG and callable_ is None
if value is NO_ARG:
value = None
@@ -1101,11 +1120,11 @@ class BindParameter(ColumnElement):
key = quoted_name(key, quote)
if unique:
- self.key = _anonymous_label('%%(%d %s)s' % (id(self), key
- or 'param'))
+ self.key = _anonymous_label(
+ "%%(%d %s)s" % (id(self), key or "param")
+ )
else:
- self.key = key or _anonymous_label('%%(%d param)s'
- % id(self))
+ self.key = key or _anonymous_label("%%(%d param)s" % id(self))
# identifying key that won't change across
# clones, used to identify the bind's logical
@@ -1114,7 +1133,7 @@ class BindParameter(ColumnElement):
# key that was passed in the first place, used to
# generate new keys
- self._orig_key = key or 'param'
+ self._orig_key = key or "param"
self.unique = unique
self.value = value
@@ -1125,9 +1144,9 @@ class BindParameter(ColumnElement):
if type_ is None:
if _compared_to_type is not None:
- self.type = \
- _compared_to_type.coerce_compared_value(
- _compared_to_operator, value)
+ self.type = _compared_to_type.coerce_compared_value(
+ _compared_to_operator, value
+ )
else:
self.type = type_api._resolve_value_to_type(value)
elif isinstance(type_, type):
@@ -1174,24 +1193,28 @@ class BindParameter(ColumnElement):
def _clone(self):
c = ClauseElement._clone(self)
if self.unique:
- c.key = _anonymous_label('%%(%d %s)s' % (id(c), c._orig_key
- or 'param'))
+ c.key = _anonymous_label(
+ "%%(%d %s)s" % (id(c), c._orig_key or "param")
+ )
return c
def _convert_to_unique(self):
if not self.unique:
self.unique = True
self.key = _anonymous_label(
- '%%(%d %s)s' % (id(self), self._orig_key or 'param'))
+ "%%(%d %s)s" % (id(self), self._orig_key or "param")
+ )
def compare(self, other, **kw):
"""Compare this :class:`BindParameter` to the given
clause."""
- return isinstance(other, BindParameter) \
- and self.type._compare_type_affinity(other.type) \
- and self.value == other.value \
+ return (
+ isinstance(other, BindParameter)
+ and self.type._compare_type_affinity(other.type)
+ and self.value == other.value
and self.callable == other.callable
+ )
def __getstate__(self):
"""execute a deferred value for serialization purposes."""
@@ -1200,13 +1223,16 @@ class BindParameter(ColumnElement):
v = self.value
if self.callable:
v = self.callable()
- d['callable'] = None
- d['value'] = v
+ d["callable"] = None
+ d["value"] = v
return d
def __repr__(self):
- return 'BindParameter(%r, %r, type_=%r)' % (self.key,
- self.value, self.type)
+ return "BindParameter(%r, %r, type_=%r)" % (
+ self.key,
+ self.value,
+ self.type,
+ )
class TypeClause(ClauseElement):
@@ -1216,7 +1242,7 @@ class TypeClause(ClauseElement):
"""
- __visit_name__ = 'typeclause'
+ __visit_name__ = "typeclause"
def __init__(self, type):
self.type = type
@@ -1242,12 +1268,12 @@ class TextClause(Executable, ClauseElement):
"""
- __visit_name__ = 'textclause'
+ __visit_name__ = "textclause"
- _bind_params_regex = re.compile(r'(?<![:\w\x5c]):(\w+)(?!:)', re.UNICODE)
- _execution_options = \
- Executable._execution_options.union(
- {'autocommit': PARSE_AUTOCOMMIT})
+ _bind_params_regex = re.compile(r"(?<![:\w\x5c]):(\w+)(?!:)", re.UNICODE)
+ _execution_options = Executable._execution_options.union(
+ {"autocommit": PARSE_AUTOCOMMIT}
+ )
_is_implicitly_boolean = False
@property
@@ -1268,24 +1294,22 @@ class TextClause(Executable, ClauseElement):
_allow_label_resolve = False
- def __init__(
- self,
- text,
- bind=None):
+ def __init__(self, text, bind=None):
self._bind = bind
self._bindparams = {}
def repl(m):
self._bindparams[m.group(1)] = BindParameter(m.group(1))
- return ':%s' % m.group(1)
+ return ":%s" % m.group(1)
# scan the string and search for bind parameter names, add them
# to the list of bindparams
self.text = self._bind_params_regex.sub(repl, text)
@classmethod
- def _create_text(self, text, bind=None, bindparams=None,
- typemap=None, autocommit=None):
+ def _create_text(
+ self, text, bind=None, bindparams=None, typemap=None, autocommit=None
+ ):
r"""Construct a new :class:`.TextClause` clause, representing
a textual SQL string directly.
@@ -1428,8 +1452,10 @@ class TextClause(Executable, ClauseElement):
if typemap:
stmt = stmt.columns(**typemap)
if autocommit is not None:
- util.warn_deprecated('autocommit on text() is deprecated. '
- 'Use .execution_options(autocommit=True)')
+ util.warn_deprecated(
+ "autocommit on text() is deprecated. "
+ "Use .execution_options(autocommit=True)"
+ )
stmt = stmt.execution_options(autocommit=autocommit)
return stmt
@@ -1513,7 +1539,8 @@ class TextClause(Executable, ClauseElement):
except KeyError:
raise exc.ArgumentError(
"This text() construct doesn't define a "
- "bound parameter named %r" % bind.key)
+ "bound parameter named %r" % bind.key
+ )
else:
new_params[existing.key] = bind
@@ -1523,11 +1550,12 @@ class TextClause(Executable, ClauseElement):
except KeyError:
raise exc.ArgumentError(
"This text() construct doesn't define a "
- "bound parameter named %r" % key)
+ "bound parameter named %r" % key
+ )
else:
new_params[key] = existing._with_value(value)
- @util.dependencies('sqlalchemy.sql.selectable')
+ @util.dependencies("sqlalchemy.sql.selectable")
def columns(self, selectable, *cols, **types):
"""Turn this :class:`.TextClause` object into a :class:`.TextAsFrom`
object that can be embedded into another statement.
@@ -1629,12 +1657,14 @@ class TextClause(Executable, ClauseElement):
for col in cols
]
keyed_input_cols = [
- ColumnClause(key, type_) for key, type_ in types.items()]
+ ColumnClause(key, type_) for key, type_ in types.items()
+ ]
return selectable.TextAsFrom(
self,
positional_input_cols + keyed_input_cols,
- positional=bool(positional_input_cols) and not keyed_input_cols)
+ positional=bool(positional_input_cols) and not keyed_input_cols,
+ )
@property
def type(self):
@@ -1651,8 +1681,9 @@ class TextClause(Executable, ClauseElement):
return self
def _copy_internals(self, clone=_clone, **kw):
- self._bindparams = dict((b.key, clone(b, **kw))
- for b in self._bindparams.values())
+ self._bindparams = dict(
+ (b.key, clone(b, **kw)) for b in self._bindparams.values()
+ )
def get_children(self, **kwargs):
return list(self._bindparams.values())
@@ -1669,7 +1700,7 @@ class Null(ColumnElement):
"""
- __visit_name__ = 'null'
+ __visit_name__ = "null"
@util.memoized_property
def type(self):
@@ -1693,7 +1724,7 @@ class False_(ColumnElement):
"""
- __visit_name__ = 'false'
+ __visit_name__ = "false"
@util.memoized_property
def type(self):
@@ -1752,7 +1783,7 @@ class True_(ColumnElement):
"""
- __visit_name__ = 'true'
+ __visit_name__ = "true"
@util.memoized_property
def type(self):
@@ -1816,23 +1847,23 @@ class ClauseList(ClauseElement):
By default, is comma-separated, such as a column listing.
"""
- __visit_name__ = 'clauselist'
+
+ __visit_name__ = "clauselist"
def __init__(self, *clauses, **kwargs):
- self.operator = kwargs.pop('operator', operators.comma_op)
- self.group = kwargs.pop('group', True)
- self.group_contents = kwargs.pop('group_contents', True)
+ self.operator = kwargs.pop("operator", operators.comma_op)
+ self.group = kwargs.pop("group", True)
+ self.group_contents = kwargs.pop("group_contents", True)
text_converter = kwargs.pop(
- '_literal_as_text',
- _expression_literal_as_text)
+ "_literal_as_text", _expression_literal_as_text
+ )
if self.group_contents:
self.clauses = [
text_converter(clause).self_group(against=self.operator)
- for clause in clauses]
+ for clause in clauses
+ ]
else:
- self.clauses = [
- text_converter(clause)
- for clause in clauses]
+ self.clauses = [text_converter(clause) for clause in clauses]
self._is_implicitly_boolean = operators.is_boolean(self.operator)
def __iter__(self):
@@ -1847,8 +1878,9 @@ class ClauseList(ClauseElement):
def append(self, clause):
if self.group_contents:
- self.clauses.append(_literal_as_text(clause).
- self_group(against=self.operator))
+ self.clauses.append(
+ _literal_as_text(clause).self_group(against=self.operator)
+ )
else:
self.clauses.append(_literal_as_text(clause))
@@ -1875,14 +1907,18 @@ class ClauseList(ClauseElement):
"""
if not isinstance(other, ClauseList) and len(self.clauses) == 1:
return self.clauses[0].compare(other, **kw)
- elif isinstance(other, ClauseList) and \
- len(self.clauses) == len(other.clauses) and \
- self.operator is other.operator:
+ elif (
+ isinstance(other, ClauseList)
+ and len(self.clauses) == len(other.clauses)
+ and self.operator is other.operator
+ ):
if self.operator in (operators.and_, operators.or_):
completed = set()
for clause in self.clauses:
- for other_clause in set(other.clauses).difference(completed):
+ for other_clause in set(other.clauses).difference(
+ completed
+ ):
if clause.compare(other_clause, **kw):
completed.add(other_clause)
break
@@ -1898,11 +1934,12 @@ class ClauseList(ClauseElement):
class BooleanClauseList(ClauseList, ColumnElement):
- __visit_name__ = 'clauselist'
+ __visit_name__ = "clauselist"
def __init__(self, *arg, **kw):
raise NotImplementedError(
- "BooleanClauseList has a private constructor")
+ "BooleanClauseList has a private constructor"
+ )
@classmethod
def _construct(cls, operator, continue_on, skip_on, *clauses, **kw):
@@ -1910,8 +1947,7 @@ class BooleanClauseList(ClauseList, ColumnElement):
clauses = [
_expression_literal_as_text(clause)
- for clause in
- util.coerce_generator_arg(clauses)
+ for clause in util.coerce_generator_arg(clauses)
]
for clause in clauses:
@@ -1927,8 +1963,9 @@ class BooleanClauseList(ClauseList, ColumnElement):
elif not convert_clauses and clauses:
return clauses[0].self_group(against=operators._asbool)
- convert_clauses = [c.self_group(against=operator)
- for c in convert_clauses]
+ convert_clauses = [
+ c.self_group(against=operator) for c in convert_clauses
+ ]
self = cls.__new__(cls)
self.clauses = convert_clauses
@@ -2014,7 +2051,7 @@ class BooleanClauseList(ClauseList, ColumnElement):
@property
def _select_iterable(self):
- return (self, )
+ return (self,)
def self_group(self, against=None):
if not self.clauses:
@@ -2056,22 +2093,31 @@ class Tuple(ClauseList, ColumnElement):
clauses = [_literal_as_binds(c) for c in clauses]
self._type_tuple = [arg.type for arg in clauses]
- self.type = kw.pop('type_', self._type_tuple[0]
- if self._type_tuple else type_api.NULLTYPE)
+ self.type = kw.pop(
+ "type_",
+ self._type_tuple[0] if self._type_tuple else type_api.NULLTYPE,
+ )
super(Tuple, self).__init__(*clauses, **kw)
@property
def _select_iterable(self):
- return (self, )
+ return (self,)
def _bind_param(self, operator, obj, type_=None):
- return Tuple(*[
- BindParameter(None, o, _compared_to_operator=operator,
- _compared_to_type=compared_to_type, unique=True,
- type_=type_)
- for o, compared_to_type in zip(obj, self._type_tuple)
- ]).self_group()
+ return Tuple(
+ *[
+ BindParameter(
+ None,
+ o,
+ _compared_to_operator=operator,
+ _compared_to_type=compared_to_type,
+ unique=True,
+ type_=type_,
+ )
+ for o, compared_to_type in zip(obj, self._type_tuple)
+ ]
+ ).self_group()
class Case(ColumnElement):
@@ -2101,7 +2147,7 @@ class Case(ColumnElement):
"""
- __visit_name__ = 'case'
+ __visit_name__ = "case"
def __init__(self, whens, value=None, else_=None):
r"""Produce a ``CASE`` expression.
@@ -2231,13 +2277,13 @@ class Case(ColumnElement):
if value is not None:
whenlist = [
- (_literal_as_binds(c).self_group(),
- _literal_as_binds(r)) for (c, r) in whens
+ (_literal_as_binds(c).self_group(), _literal_as_binds(r))
+ for (c, r) in whens
]
else:
whenlist = [
- (_no_literals(c).self_group(),
- _literal_as_binds(r)) for (c, r) in whens
+ (_no_literals(c).self_group(), _literal_as_binds(r))
+ for (c, r) in whens
]
if whenlist:
@@ -2260,8 +2306,7 @@ class Case(ColumnElement):
def _copy_internals(self, clone=_clone, **kw):
if self.value is not None:
self.value = clone(self.value, **kw)
- self.whens = [(clone(x, **kw), clone(y, **kw))
- for x, y in self.whens]
+ self.whens = [(clone(x, **kw), clone(y, **kw)) for x, y in self.whens]
if self.else_ is not None:
self.else_ = clone(self.else_, **kw)
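
Illustrative sketch (not part of the patch) of the list-of-tuples whens form handled above; the values are made up.

    from sqlalchemy import case, column

    expr = case([(column("x") > 0, "positive")], else_="non-positive")
    print(expr)
    # roughly: CASE WHEN x > :x_1 THEN :param_1 ELSE :param_2 END
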
@@ -2276,8 +2321,9 @@ class Case(ColumnElement):
@property
def _from_objects(self):
- return list(itertools.chain(*[x._from_objects for x in
- self.get_children()]))
+ return list(
+ itertools.chain(*[x._from_objects for x in self.get_children()])
+ )
def literal_column(text, type_=None):
@@ -2333,7 +2379,7 @@ class Cast(ColumnElement):
"""
- __visit_name__ = 'cast'
+ __visit_name__ = "cast"
def __init__(self, expression, type_):
"""Produce a ``CAST`` expression.
@@ -2416,7 +2462,7 @@ class TypeCoerce(ColumnElement):
"""
- __visit_name__ = 'type_coerce'
+ __visit_name__ = "type_coerce"
def __init__(self, expression, type_):
"""Associate a SQL expression with a particular type, without rendering
@@ -2484,10 +2530,10 @@ class TypeCoerce(ColumnElement):
def _copy_internals(self, clone=_clone, **kw):
self.clause = clone(self.clause, **kw)
- self.__dict__.pop('typed_expression', None)
+ self.__dict__.pop("typed_expression", None)
def get_children(self, **kwargs):
- return self.clause,
+ return (self.clause,)
@property
def _from_objects(self):
@@ -2506,7 +2552,7 @@ class TypeCoerce(ColumnElement):
class Extract(ColumnElement):
"""Represent a SQL EXTRACT clause, ``extract(field FROM expr)``."""
- __visit_name__ = 'extract'
+ __visit_name__ = "extract"
def __init__(self, field, expr, **kwargs):
"""Return a :class:`.Extract` construct.
@@ -2524,7 +2570,7 @@ class Extract(ColumnElement):
self.expr = clone(self.expr, **kw)
def get_children(self, **kwargs):
- return self.expr,
+ return (self.expr,)
@property
def _from_objects(self):
@@ -2543,7 +2589,8 @@ class _label_reference(ColumnElement):
within an OVER clause.
"""
- __visit_name__ = 'label_reference'
+
+ __visit_name__ = "label_reference"
def __init__(self, element):
self.element = element
@@ -2557,7 +2604,7 @@ class _label_reference(ColumnElement):
class _textual_label_reference(ColumnElement):
- __visit_name__ = 'textual_label_reference'
+ __visit_name__ = "textual_label_reference"
def __init__(self, element):
self.element = element
@@ -2580,14 +2627,23 @@ class UnaryExpression(ColumnElement):
:func:`.nullsfirst` and :func:`.nullslast`.
"""
- __visit_name__ = 'unary'
- def __init__(self, element, operator=None, modifier=None,
- type_=None, negate=None, wraps_column_expression=False):
+ __visit_name__ = "unary"
+
+ def __init__(
+ self,
+ element,
+ operator=None,
+ modifier=None,
+ type_=None,
+ negate=None,
+ wraps_column_expression=False,
+ ):
self.operator = operator
self.modifier = modifier
self.element = element.self_group(
- against=self.operator or self.modifier)
+ against=self.operator or self.modifier
+ )
self.type = type_api.to_instance(type_)
self.negate = negate
self.wraps_column_expression = wraps_column_expression
@@ -2633,7 +2689,8 @@ class UnaryExpression(ColumnElement):
return UnaryExpression(
_literal_as_label_reference(column),
modifier=operators.nullsfirst_op,
- wraps_column_expression=False)
+ wraps_column_expression=False,
+ )
@classmethod
def _create_nullslast(cls, column):
@@ -2675,7 +2732,8 @@ class UnaryExpression(ColumnElement):
return UnaryExpression(
_literal_as_label_reference(column),
modifier=operators.nullslast_op,
- wraps_column_expression=False)
+ wraps_column_expression=False,
+ )
@classmethod
def _create_desc(cls, column):
@@ -2715,7 +2773,8 @@ class UnaryExpression(ColumnElement):
return UnaryExpression(
_literal_as_label_reference(column),
modifier=operators.desc_op,
- wraps_column_expression=False)
+ wraps_column_expression=False,
+ )
@classmethod
def _create_asc(cls, column):
@@ -2754,7 +2813,8 @@ class UnaryExpression(ColumnElement):
return UnaryExpression(
_literal_as_label_reference(column),
modifier=operators.asc_op,
- wraps_column_expression=False)
+ wraps_column_expression=False,
+ )
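
For orientation (not part of the patch): the ordering factories above back the public desc()/asc()/nullsfirst()/nullslast() functions.

    from sqlalchemy import column, desc, nullslast

    print(desc(column("x")))             # x DESC
    print(nullslast(desc(column("x"))))  # x DESC NULLS LAST
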
@classmethod
def _create_distinct(cls, expr):
@@ -2794,8 +2854,11 @@ class UnaryExpression(ColumnElement):
"""
expr = _literal_as_binds(expr)
return UnaryExpression(
- expr, operator=operators.distinct_op,
- type_=expr.type, wraps_column_expression=False)
+ expr,
+ operator=operators.distinct_op,
+ type_=expr.type,
+ wraps_column_expression=False,
+ )
@property
def _order_by_label_element(self):
@@ -2812,17 +2875,17 @@ class UnaryExpression(ColumnElement):
self.element = clone(self.element, **kw)
def get_children(self, **kwargs):
- return self.element,
+ return (self.element,)
def compare(self, other, **kw):
"""Compare this :class:`UnaryExpression` against the given
:class:`.ClauseElement`."""
return (
- isinstance(other, UnaryExpression) and
- self.operator == other.operator and
- self.modifier == other.modifier and
- self.element.compare(other.element, **kw)
+ isinstance(other, UnaryExpression)
+ and self.operator == other.operator
+ and self.modifier == other.modifier
+ and self.element.compare(other.element, **kw)
)
def _negate(self):
@@ -2833,14 +2896,16 @@ class UnaryExpression(ColumnElement):
negate=self.operator,
modifier=self.modifier,
type_=self.type,
- wraps_column_expression=self.wraps_column_expression)
+ wraps_column_expression=self.wraps_column_expression,
+ )
elif self.type._type_affinity is type_api.BOOLEANTYPE._type_affinity:
return UnaryExpression(
self.self_group(against=operators.inv),
operator=operators.inv,
type_=type_api.BOOLEANTYPE,
wraps_column_expression=self.wraps_column_expression,
- negate=None)
+ negate=None,
+ )
else:
return ClauseElement._negate(self)
@@ -2860,6 +2925,7 @@ class CollectionAggregate(UnaryExpression):
MySQL, they only work for subqueries.
"""
+
@classmethod
def _create_any(cls, expr):
"""Produce an ANY expression.
@@ -2883,12 +2949,15 @@ class CollectionAggregate(UnaryExpression):
expr = _literal_as_binds(expr)
- if expr.is_selectable and hasattr(expr, 'as_scalar'):
+ if expr.is_selectable and hasattr(expr, "as_scalar"):
expr = expr.as_scalar()
expr = expr.self_group()
return CollectionAggregate(
- expr, operator=operators.any_op,
- type_=type_api.NULLTYPE, wraps_column_expression=False)
+ expr,
+ operator=operators.any_op,
+ type_=type_api.NULLTYPE,
+ wraps_column_expression=False,
+ )
@classmethod
def _create_all(cls, expr):
@@ -2912,12 +2981,15 @@ class CollectionAggregate(UnaryExpression):
"""
expr = _literal_as_binds(expr)
- if expr.is_selectable and hasattr(expr, 'as_scalar'):
+ if expr.is_selectable and hasattr(expr, "as_scalar"):
expr = expr.as_scalar()
expr = expr.self_group()
return CollectionAggregate(
- expr, operator=operators.all_op,
- type_=type_api.NULLTYPE, wraps_column_expression=False)
+ expr,
+ operator=operators.all_op,
+ type_=type_api.NULLTYPE,
+ wraps_column_expression=False,
+ )
# operate and reverse_operate are hardwired to
# dispatch onto the type comparator directly, so that we can
@@ -2925,19 +2997,20 @@ class CollectionAggregate(UnaryExpression):
def operate(self, op, *other, **kwargs):
if not operators.is_comparison(op):
raise exc.ArgumentError(
- "Only comparison operators may be used with ANY/ALL")
- kwargs['reverse'] = True
+ "Only comparison operators may be used with ANY/ALL"
+ )
+ kwargs["reverse"] = True
return self.comparator.operate(operators.mirror(op), *other, **kwargs)
def reverse_operate(self, op, other, **kwargs):
# comparison operators should never call reverse_operate
assert not operators.is_comparison(op)
raise exc.ArgumentError(
- "Only comparison operators may be used with ANY/ALL")
+ "Only comparison operators may be used with ANY/ALL"
+ )
class AsBoolean(UnaryExpression):
-
def __init__(self, element, operator, negate):
self.element = element
self.type = type_api.BOOLEANTYPE
@@ -2971,7 +3044,7 @@ class BinaryExpression(ColumnElement):
"""
- __visit_name__ = 'binary'
+ __visit_name__ = "binary"
_is_implicitly_boolean = True
"""Indicates that any database will know this is a boolean expression
@@ -2979,8 +3052,9 @@ class BinaryExpression(ColumnElement):
"""
- def __init__(self, left, right, operator, type_=None,
- negate=None, modifiers=None):
+ def __init__(
+ self, left, right, operator, type_=None, negate=None, modifiers=None
+ ):
# allow compatibility with libraries that
# refer to BinaryExpression directly and pass strings
if isinstance(operator, util.string_types):
@@ -3026,15 +3100,15 @@ class BinaryExpression(ColumnElement):
given :class:`BinaryExpression`."""
return (
- isinstance(other, BinaryExpression) and
- self.operator == other.operator and
- (
- self.left.compare(other.left, **kw) and
- self.right.compare(other.right, **kw) or
- (
- operators.is_commutative(self.operator) and
- self.left.compare(other.right, **kw) and
- self.right.compare(other.left, **kw)
+ isinstance(other, BinaryExpression)
+ and self.operator == other.operator
+ and (
+ self.left.compare(other.left, **kw)
+ and self.right.compare(other.right, **kw)
+ or (
+ operators.is_commutative(self.operator)
+ and self.left.compare(other.right, **kw)
+ and self.right.compare(other.left, **kw)
)
)
)
@@ -3053,7 +3127,8 @@ class BinaryExpression(ColumnElement):
self.negate,
negate=self.operator,
type_=self.type,
- modifiers=self.modifiers)
+ modifiers=self.modifiers,
+ )
else:
return super(BinaryExpression, self)._negate()
@@ -3065,7 +3140,8 @@ class Slice(ColumnElement):
may be interpreted by specific dialects, e.g. PostgreSQL.
"""
- __visit_name__ = 'slice'
+
+ __visit_name__ = "slice"
def __init__(self, start, stop, step):
self.start = start
@@ -3081,17 +3157,18 @@ class Slice(ColumnElement):
class IndexExpression(BinaryExpression):
"""Represent the class of expressions that are like an "index" operation.
"""
+
pass
class Grouping(ColumnElement):
"""Represent a grouping within a column expression"""
- __visit_name__ = 'grouping'
+ __visit_name__ = "grouping"
def __init__(self, element):
self.element = element
- self.type = getattr(element, 'type', type_api.NULLTYPE)
+ self.type = getattr(element, "type", type_api.NULLTYPE)
def self_group(self, against=None):
return self
@@ -3106,13 +3183,13 @@ class Grouping(ColumnElement):
@property
def _label(self):
- return getattr(self.element, '_label', None) or self.anon_label
+ return getattr(self.element, "_label", None) or self.anon_label
def _copy_internals(self, clone=_clone, **kw):
self.element = clone(self.element, **kw)
def get_children(self, **kwargs):
- return self.element,
+ return (self.element,)
@property
def _from_objects(self):
@@ -3122,15 +3199,16 @@ class Grouping(ColumnElement):
return getattr(self.element, attr)
def __getstate__(self):
- return {'element': self.element, 'type': self.type}
+ return {"element": self.element, "type": self.type}
def __setstate__(self, state):
- self.element = state['element']
- self.type = state['type']
+ self.element = state["element"]
+ self.type = state["type"]
def compare(self, other, **kw):
- return isinstance(other, Grouping) and \
- self.element.compare(other.element)
+ return isinstance(other, Grouping) and self.element.compare(
+ other.element
+ )
RANGE_UNBOUNDED = util.symbol("RANGE_UNBOUNDED")
@@ -3147,14 +3225,15 @@ class Over(ColumnElement):
backends.
"""
- __visit_name__ = 'over'
+
+ __visit_name__ = "over"
order_by = None
partition_by = None
def __init__(
- self, element, partition_by=None,
- order_by=None, range_=None, rows=None):
+ self, element, partition_by=None, order_by=None, range_=None, rows=None
+ ):
"""Produce an :class:`.Over` object against a function.
Used against aggregate or so-called "window" functions,
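
A sketch (not part of the patch) of Over as produced by FunctionElement.over(); the column names are illustrative.

    from sqlalchemy import column, func

    stmt = func.row_number().over(
        partition_by=column("dept"), order_by=column("salary")
    )
    print(stmt)
    # row_number() OVER (PARTITION BY dept ORDER BY salary)
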
@@ -3237,17 +3316,20 @@ class Over(ColumnElement):
if order_by is not None:
self.order_by = ClauseList(
*util.to_list(order_by),
- _literal_as_text=_literal_as_label_reference)
+ _literal_as_text=_literal_as_label_reference
+ )
if partition_by is not None:
self.partition_by = ClauseList(
*util.to_list(partition_by),
- _literal_as_text=_literal_as_label_reference)
+ _literal_as_text=_literal_as_label_reference
+ )
if range_:
self.range_ = self._interpret_range(range_)
if rows:
raise exc.ArgumentError(
- "'range_' and 'rows' are mutually exclusive")
+ "'range_' and 'rows' are mutually exclusive"
+ )
else:
self.rows = None
elif rows:
@@ -3267,7 +3349,8 @@ class Over(ColumnElement):
lower = int(range_[0])
except ValueError:
raise exc.ArgumentError(
- "Integer or None expected for range value")
+ "Integer or None expected for range value"
+ )
else:
if lower == 0:
lower = RANGE_CURRENT
@@ -3279,7 +3362,8 @@ class Over(ColumnElement):
upper = int(range_[1])
except ValueError:
raise exc.ArgumentError(
- "Integer or None expected for range value")
+ "Integer or None expected for range value"
+ )
else:
if upper == 0:
upper = RANGE_CURRENT
@@ -3303,9 +3387,11 @@ class Over(ColumnElement):
return self.element.type
def get_children(self, **kwargs):
- return [c for c in
- (self.element, self.partition_by, self.order_by)
- if c is not None]
+ return [
+ c
+ for c in (self.element, self.partition_by, self.order_by)
+ if c is not None
+ ]
def _copy_internals(self, clone=_clone, **kw):
self.element = clone(self.element, **kw)
@@ -3316,11 +3402,15 @@ class Over(ColumnElement):
@property
def _from_objects(self):
- return list(itertools.chain(
- *[c._from_objects for c in
- (self.element, self.partition_by, self.order_by)
- if c is not None]
- ))
+ return list(
+ itertools.chain(
+ *[
+ c._from_objects
+ for c in (self.element, self.partition_by, self.order_by)
+ if c is not None
+ ]
+ )
+ )
class WithinGroup(ColumnElement):
@@ -3339,7 +3429,8 @@ class WithinGroup(ColumnElement):
``None``, the function's ``.type`` is used.
"""
- __visit_name__ = 'withingroup'
+
+ __visit_name__ = "withingroup"
order_by = None
@@ -3383,7 +3474,8 @@ class WithinGroup(ColumnElement):
if order_by is not None:
self.order_by = ClauseList(
*util.to_list(order_by),
- _literal_as_text=_literal_as_label_reference)
+ _literal_as_text=_literal_as_label_reference
+ )
def over(self, partition_by=None, order_by=None, range_=None, rows=None):
"""Produce an OVER clause against this :class:`.WithinGroup`
@@ -3394,8 +3486,12 @@ class WithinGroup(ColumnElement):
"""
return Over(
- self, partition_by=partition_by, order_by=order_by,
- range_=range_, rows=rows)
+ self,
+ partition_by=partition_by,
+ order_by=order_by,
+ range_=range_,
+ rows=rows,
+ )
@util.memoized_property
def type(self):
@@ -3406,9 +3502,7 @@ class WithinGroup(ColumnElement):
return self.element.type
def get_children(self, **kwargs):
- return [c for c in
- (self.element, self.order_by)
- if c is not None]
+ return [c for c in (self.element, self.order_by) if c is not None]
def _copy_internals(self, clone=_clone, **kw):
self.element = clone(self.element, **kw)
@@ -3417,11 +3511,15 @@ class WithinGroup(ColumnElement):
@property
def _from_objects(self):
- return list(itertools.chain(
- *[c._from_objects for c in
- (self.element, self.order_by)
- if c is not None]
- ))
+ return list(
+ itertools.chain(
+ *[
+ c._from_objects
+ for c in (self.element, self.order_by)
+ if c is not None
+ ]
+ )
+ )
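:class:`.WithinGroup` is produced by :meth:`.FunctionElement.within_group`; a short sketch against a hypothetical ``income`` column::

    from sqlalchemy import func
    from sqlalchemy.sql import column, select

    income = column("income")

    # percentile_cont(0.5) WITHIN GROUP (ORDER BY income DESC)
    stmt = select([func.percentile_cont(0.5).within_group(income.desc())])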
class FunctionFilter(ColumnElement):
@@ -3443,7 +3541,8 @@ class FunctionFilter(ColumnElement):
:meth:`.FunctionElement.filter`
"""
- __visit_name__ = 'funcfilter'
+
+ __visit_name__ = "funcfilter"
criterion = None
@@ -3515,17 +3614,19 @@ class FunctionFilter(ColumnElement):
"""
return Over(
- self, partition_by=partition_by, order_by=order_by,
- range_=range_, rows=rows)
+ self,
+ partition_by=partition_by,
+ order_by=order_by,
+ range_=range_,
+ rows=rows,
+ )
@util.memoized_property
def type(self):
return self.func.type
def get_children(self, **kwargs):
- return [c for c in
- (self.func, self.criterion)
- if c is not None]
+ return [c for c in (self.func, self.criterion) if c is not None]
def _copy_internals(self, clone=_clone, **kw):
self.func = clone(self.func, **kw)
@@ -3534,10 +3635,15 @@ class FunctionFilter(ColumnElement):
@property
def _from_objects(self):
- return list(itertools.chain(
- *[c._from_objects for c in (self.func, self.criterion)
- if c is not None]
- ))
+ return list(
+ itertools.chain(
+ *[
+ c._from_objects
+ for c in (self.func, self.criterion)
+ if c is not None
+ ]
+ )
+ )
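:class:`.FunctionFilter` is produced by :meth:`.FunctionElement.filter`; a short sketch with an illustrative ``age`` column::

    from sqlalchemy import func
    from sqlalchemy.sql import column, select

    age = column("age")

    # count(1) FILTER (WHERE age > 40), on backends that support FILTER
    stmt = select([func.count(1).filter(age > 40)])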
class Label(ColumnElement):
@@ -3548,7 +3654,7 @@ class Label(ColumnElement):
"""
- __visit_name__ = 'label'
+ __visit_name__ = "label"
def __init__(self, name, element, type_=None):
"""Return a :class:`Label` object for the
@@ -3577,7 +3683,7 @@ class Label(ColumnElement):
self._resolve_label = self.name
else:
self.name = _anonymous_label(
- '%%(%d %s)s' % (id(self), getattr(element, 'name', 'anon'))
+ "%%(%d %s)s" % (id(self), getattr(element, "name", "anon"))
)
self.key = self._label = self._key_label = self.name
@@ -3603,7 +3709,7 @@ class Label(ColumnElement):
@util.memoized_property
def type(self):
return type_api.to_instance(
- self._type or getattr(self._element, 'type', None)
+ self._type or getattr(self._element, "type", None)
)
@util.memoized_property
@@ -3619,9 +3725,7 @@ class Label(ColumnElement):
def _apply_to_inner(self, fn, *arg, **kw):
sub_element = fn(*arg, **kw)
if sub_element is not self._element:
- return Label(self.name,
- sub_element,
- type_=self._type)
+ return Label(self.name, sub_element, type_=self._type)
else:
return self
@@ -3634,16 +3738,16 @@ class Label(ColumnElement):
return self.element.foreign_keys
def get_children(self, **kwargs):
- return self.element,
+ return (self.element,)
def _copy_internals(self, clone=_clone, anonymize_labels=False, **kw):
self._element = clone(self._element, **kw)
- self.__dict__.pop('element', None)
- self.__dict__.pop('_allow_label_resolve', None)
+ self.__dict__.pop("element", None)
+ self.__dict__.pop("_allow_label_resolve", None)
if anonymize_labels:
self.name = self._resolve_label = _anonymous_label(
- '%%(%d %s)s' % (
- id(self), getattr(self.element, 'name', 'anon'))
+ "%%(%d %s)s"
+ % (id(self), getattr(self.element, "name", "anon"))
)
self.key = self._label = self._key_label = self.name
@@ -3652,8 +3756,9 @@ class Label(ColumnElement):
return self.element._from_objects
def _make_proxy(self, selectable, name=None, **kw):
- e = self.element._make_proxy(selectable,
- name=name if name else self.name)
+ e = self.element._make_proxy(
+ selectable, name=name if name else self.name
+ )
e._proxies.append(self)
if self._type is not None:
e.type = self._type
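A quick sketch of :class:`.Label`, including the anonymous-label path taken when ``name`` is ``None`` (column name is illustrative)::

    from sqlalchemy.sql import column, select

    name = column("name")

    stmt = select([name.label("employee_name")])  # name AS employee_name
    anon = name.label(None)  # anonymous label, named at compile time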
@@ -3694,7 +3799,8 @@ class ColumnClause(Immutable, ColumnElement):
:class:`.Column`
"""
- __visit_name__ = 'column'
+
+ __visit_name__ = "column"
onupdate = default = server_default = server_onupdate = None
@@ -3792,25 +3898,33 @@ class ColumnClause(Immutable, ColumnElement):
self.is_literal = is_literal
def _compare_name_for_result(self, other):
- if self.is_literal or \
- self.table is None or self.table._textual or \
- not hasattr(other, 'proxy_set') or (
- isinstance(other, ColumnClause) and
- (other.is_literal or
- other.table is None or
- other.table._textual)
- ):
- return (hasattr(other, 'name') and self.name == other.name) or \
- (hasattr(other, '_label') and self._label == other._label)
+ if (
+ self.is_literal
+ or self.table is None
+ or self.table._textual
+ or not hasattr(other, "proxy_set")
+ or (
+ isinstance(other, ColumnClause)
+ and (
+ other.is_literal
+ or other.table is None
+ or other.table._textual
+ )
+ )
+ ):
+ return (hasattr(other, "name") and self.name == other.name) or (
+ hasattr(other, "_label") and self._label == other._label
+ )
else:
return other.proxy_set.intersection(self.proxy_set)
def _get_table(self):
- return self.__dict__['table']
+ return self.__dict__["table"]
def _set_table(self, table):
self._memoized_property.expire_instance(self)
- self.__dict__['table'] = table
+ self.__dict__["table"] = table
+
table = property(_get_table, _set_table)
@_memoized_property
@@ -3826,7 +3940,7 @@ class ColumnClause(Immutable, ColumnElement):
if util.py3k:
return self.name
else:
- return self.name.encode('ascii', 'backslashreplace')
+ return self.name.encode("ascii", "backslashreplace")
@_memoized_property
def _key_label(self):
@@ -3850,9 +3964,8 @@ class ColumnClause(Immutable, ColumnElement):
return None
elif t is not None and t.named_with_column:
- if getattr(t, 'schema', None):
- label = t.schema.replace('.', '_') + "_" + \
- t.name + "_" + name
+ if getattr(t, "schema", None):
+ label = t.schema.replace(".", "_") + "_" + t.name + "_" + name
else:
label = t.name + "_" + name
@@ -3884,31 +3997,39 @@ class ColumnClause(Immutable, ColumnElement):
return name
def _bind_param(self, operator, obj, type_=None):
- return BindParameter(self.key, obj,
- _compared_to_operator=operator,
- _compared_to_type=self.type,
- type_=type_,
- unique=True)
-
- def _make_proxy(self, selectable, name=None, attach=True,
- name_is_truncatable=False, **kw):
+ return BindParameter(
+ self.key,
+ obj,
+ _compared_to_operator=operator,
+ _compared_to_type=self.type,
+ type_=type_,
+ unique=True,
+ )
+
+ def _make_proxy(
+ self,
+ selectable,
+ name=None,
+ attach=True,
+ name_is_truncatable=False,
+ **kw
+ ):
# propagate the "is_literal" flag only if we are keeping our name,
# otherwise it's considered to be a label
is_literal = self.is_literal and (name is None or name == self.name)
c = self._constructor(
- _as_truncated(name or self.name) if
- name_is_truncatable else
- (name or self.name),
+ _as_truncated(name or self.name)
+ if name_is_truncatable
+ else (name or self.name),
type_=self.type,
_selectable=selectable,
- is_literal=is_literal
+ is_literal=is_literal,
)
if name is None:
c.key = self.key
c._proxies = [self]
if selectable._is_clone_of is not None:
- c._is_clone_of = \
- selectable._is_clone_of.columns.get(c.key)
+ c._is_clone_of = selectable._is_clone_of.columns.get(c.key)
if attach:
selectable._columns[c.key] = c
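:class:`.ColumnClause` is the construct behind :func:`.column` and :func:`.literal_column`; a minimal sketch (names are illustrative)::

    from sqlalchemy.sql import column, literal_column, select, table

    user = table("user", column("id"), column("name"))

    stmt = select([user.c.id]).where(user.c.name == "ed")

    # is_literal=True is what literal_column() produces: rendered verbatim
    stmt2 = select([literal_column("count(*)")]).select_from(user)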
@@ -3924,24 +4045,25 @@ class CollationClause(ColumnElement):
class _IdentifiedClause(Executable, ClauseElement):
- __visit_name__ = 'identified'
- _execution_options = \
- Executable._execution_options.union({'autocommit': False})
+ __visit_name__ = "identified"
+ _execution_options = Executable._execution_options.union(
+ {"autocommit": False}
+ )
def __init__(self, ident):
self.ident = ident
class SavepointClause(_IdentifiedClause):
- __visit_name__ = 'savepoint'
+ __visit_name__ = "savepoint"
class RollbackToSavepointClause(_IdentifiedClause):
- __visit_name__ = 'rollback_to_savepoint'
+ __visit_name__ = "rollback_to_savepoint"
class ReleaseSavepointClause(_IdentifiedClause):
- __visit_name__ = 'release_savepoint'
+ __visit_name__ = "release_savepoint"
class quoted_name(util.MemoizedSlots, util.text_type):
@@ -3992,7 +4114,7 @@ class quoted_name(util.MemoizedSlots, util.text_type):
"""
- __slots__ = 'quote', 'lower', 'upper'
+ __slots__ = "quote", "lower", "upper"
def __new__(cls, value, quote):
if value is None:
@@ -4026,9 +4148,9 @@ class quoted_name(util.MemoizedSlots, util.text_type):
return util.text_type(self).upper()
def __repr__(self):
- backslashed = self.encode('ascii', 'backslashreplace')
+ backslashed = self.encode("ascii", "backslashreplace")
if not util.py2k:
- backslashed = backslashed.decode('ascii')
+ backslashed = backslashed.decode("ascii")
return "'%s'" % backslashed
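A short sketch of :class:`.quoted_name`, which forces quoting behavior for an identifier regardless of dialect casing rules::

    from sqlalchemy.sql import quoted_name

    # always render the identifier quoted, preserving its casing
    name = quoted_name("MixedCase", quote=True)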
@@ -4094,6 +4216,7 @@ class conv(_truncated_label):
:ref:`constraint_naming_conventions`
"""
+
__slots__ = ()
@@ -4102,6 +4225,7 @@ class _defer_name(_truncated_label):
generation.
"""
+
__slots__ = ()
def __new__(cls, value):
@@ -4113,13 +4237,15 @@ class _defer_name(_truncated_label):
return super(_defer_name, cls).__new__(cls, value)
def __reduce__(self):
- return self.__class__, (util.text_type(self), )
+ return self.__class__, (util.text_type(self),)
class _defer_none_name(_defer_name):
"""indicate a 'deferred' name that was ultimately the value None."""
+
__slots__ = ()
+
_NONE_NAME = _defer_none_name("_unnamed_")
# for backwards compatibility in case
@@ -4138,15 +4264,15 @@ class _anonymous_label(_truncated_label):
def __add__(self, other):
return _anonymous_label(
quoted_name(
- util.text_type.__add__(self, util.text_type(other)),
- self.quote)
+ util.text_type.__add__(self, util.text_type(other)), self.quote
+ )
)
def __radd__(self, other):
return _anonymous_label(
quoted_name(
- util.text_type.__add__(util.text_type(other), self),
- self.quote)
+ util.text_type.__add__(util.text_type(other), self), self.quote
+ )
)
def apply_map(self, map_):
@@ -4206,20 +4332,23 @@ def _cloned_intersection(a, b):
"""
all_overlap = set(_expand_cloned(a)).intersection(_expand_cloned(b))
- return set(elem for elem in a
- if all_overlap.intersection(elem._cloned_set))
+ return set(
+ elem for elem in a if all_overlap.intersection(elem._cloned_set)
+ )
def _cloned_difference(a, b):
all_overlap = set(_expand_cloned(a)).intersection(_expand_cloned(b))
- return set(elem for elem in a
- if not all_overlap.intersection(elem._cloned_set))
+ return set(
+ elem for elem in a if not all_overlap.intersection(elem._cloned_set)
+ )
@util.dependencies("sqlalchemy.sql.functions")
def _labeled(functions, element):
- if not hasattr(element, 'name') or \
- isinstance(element, functions.FunctionElement):
+ if not hasattr(element, "name") or isinstance(
+ element, functions.FunctionElement
+ ):
return element.label(None)
else:
return element
@@ -4235,7 +4364,7 @@ def _find_columns(clause):
"""locate Column objects within the given expression."""
cols = util.column_set()
- traverse(clause, {}, {'column': cols.add})
+ traverse(clause, {}, {"column": cols.add})
return cols
@@ -4253,7 +4382,7 @@ def _find_columns(clause):
def _column_as_key(element):
if isinstance(element, util.string_types):
return element
- if hasattr(element, '__clause_element__'):
+ if hasattr(element, "__clause_element__"):
element = element.__clause_element__()
try:
return element.key
@@ -4262,7 +4391,7 @@ def _column_as_key(element):
def _clause_element_as_expr(element):
- if hasattr(element, '__clause_element__'):
+ if hasattr(element, "__clause_element__"):
return element.__clause_element__()
else:
return element
@@ -4272,7 +4401,7 @@ def _literal_as_label_reference(element):
if isinstance(element, util.string_types):
return _textual_label_reference(element)
- elif hasattr(element, '__clause_element__'):
+ elif hasattr(element, "__clause_element__"):
element = element.__clause_element__()
return _literal_as_text(element)
@@ -4282,11 +4411,13 @@ def _literal_and_labels_as_label_reference(element):
if isinstance(element, util.string_types):
return _textual_label_reference(element)
- elif hasattr(element, '__clause_element__'):
+ elif hasattr(element, "__clause_element__"):
element = element.__clause_element__()
- if isinstance(element, ColumnElement) and \
- element._order_by_label_element is not None:
+ if (
+ isinstance(element, ColumnElement)
+ and element._order_by_label_element is not None
+ ):
return _label_reference(element)
else:
return _literal_as_text(element)
@@ -4299,14 +4430,15 @@ def _expression_literal_as_text(element):
def _literal_as_text(element, warn=False):
if isinstance(element, Visitable):
return element
- elif hasattr(element, '__clause_element__'):
+ elif hasattr(element, "__clause_element__"):
return element.__clause_element__()
elif isinstance(element, util.string_types):
if warn:
util.warn_limited(
"Textual SQL expression %(expr)r should be "
"explicitly declared as text(%(expr)r)",
- {"expr": util.ellipses_string(element)})
+ {"expr": util.ellipses_string(element)},
+ )
return TextClause(util.text_type(element))
elif isinstance(element, (util.NoneType, bool)):
@@ -4319,20 +4451,23 @@ def _literal_as_text(element, warn=False):
def _no_literals(element):
- if hasattr(element, '__clause_element__'):
+ if hasattr(element, "__clause_element__"):
return element.__clause_element__()
elif not isinstance(element, Visitable):
- raise exc.ArgumentError("Ambiguous literal: %r. Use the 'text()' "
- "function to indicate a SQL expression "
- "literal, or 'literal()' to indicate a "
- "bound value." % (element, ))
+ raise exc.ArgumentError(
+ "Ambiguous literal: %r. Use the 'text()' "
+ "function to indicate a SQL expression "
+ "literal, or 'literal()' to indicate a "
+ "bound value." % (element,)
+ )
else:
return element
def _is_literal(element):
- return not isinstance(element, Visitable) and \
- not hasattr(element, '__clause_element__')
+ return not isinstance(element, Visitable) and not hasattr(
+ element, "__clause_element__"
+ )
def _only_column_elements_or_none(element, name):
@@ -4343,17 +4478,18 @@ def _only_column_elements_or_none(element, name):
def _only_column_elements(element, name):
- if hasattr(element, '__clause_element__'):
+ if hasattr(element, "__clause_element__"):
element = element.__clause_element__()
if not isinstance(element, ColumnElement):
raise exc.ArgumentError(
"Column-based expression object expected for argument "
- "'%s'; got: '%s', type %s" % (name, element, type(element)))
+ "'%s'; got: '%s', type %s" % (name, element, type(element))
+ )
return element
def _literal_as_binds(element, name=None, type_=None):
- if hasattr(element, '__clause_element__'):
+ if hasattr(element, "__clause_element__"):
return element.__clause_element__()
elif not isinstance(element, Visitable):
if element is None:
@@ -4363,13 +4499,14 @@ def _literal_as_binds(element, name=None, type_=None):
else:
return element
-_guess_straight_column = re.compile(r'^\w\S*$', re.I)
+
+_guess_straight_column = re.compile(r"^\w\S*$", re.I)
def _interpret_as_column_or_from(element):
if isinstance(element, Visitable):
return element
- elif hasattr(element, '__clause_element__'):
+ elif hasattr(element, "__clause_element__"):
return element.__clause_element__()
insp = inspection.inspect(element, raiseerr=False)
@@ -4399,11 +4536,11 @@ def _interpret_as_column_or_from(element):
{
"column": util.ellipses_string(element),
"literal_column": "literal_column"
- if guess_is_literal else "column"
- })
- return ColumnClause(
- element,
- is_literal=guess_is_literal)
+ if guess_is_literal
+ else "column",
+ },
+ )
+ return ColumnClause(element, is_literal=guess_is_literal)
def _const_expr(element):
@@ -4416,9 +4553,7 @@ def _const_expr(element):
elif element is True:
return True_()
else:
- raise exc.ArgumentError(
- "Expected None, False, or True"
- )
+ raise exc.ArgumentError("Expected None, False, or True")
def _type_from_args(args):
@@ -4429,18 +4564,15 @@ def _type_from_args(args):
return type_api.NULLTYPE
-def _corresponding_column_or_error(fromclause, column,
- require_embedded=False):
- c = fromclause.corresponding_column(column,
- require_embedded=require_embedded)
+def _corresponding_column_or_error(fromclause, column, require_embedded=False):
+ c = fromclause.corresponding_column(
+ column, require_embedded=require_embedded
+ )
if c is None:
raise exc.InvalidRequestError(
"Given column '%s', attached to table '%s', "
"failed to locate a corresponding column from table '%s'"
- %
- (column,
- getattr(column, 'table', None),
- fromclause.description)
+ % (column, getattr(column, "table", None), fromclause.description)
)
return c
@@ -4449,7 +4581,7 @@ class AnnotatedColumnElement(Annotated):
def __init__(self, element, values):
Annotated.__init__(self, element, values)
ColumnElement.comparator._reset(self)
- for attr in ('name', 'key', 'table'):
+ for attr in ("name", "key", "table"):
if self.__dict__.get(attr, False) is None:
self.__dict__.pop(attr)
diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py
index b69b6ee8c..aab9f46d4 100644
--- a/lib/sqlalchemy/sql/expression.py
+++ b/lib/sqlalchemy/sql/expression.py
@@ -15,43 +15,142 @@ class.
"""
__all__ = [
- 'Alias', 'any_', 'all_', 'ClauseElement', 'ColumnCollection', 'ColumnElement',
- 'CompoundSelect', 'Delete', 'FromClause', 'Insert', 'Join', 'Lateral',
- 'Select',
- 'Selectable', 'TableClause', 'Update', 'alias', 'and_', 'asc', 'between',
- 'bindparam', 'case', 'cast', 'column', 'delete', 'desc', 'distinct',
- 'except_', 'except_all', 'exists', 'extract', 'func', 'modifier',
- 'collate', 'insert', 'intersect', 'intersect_all', 'join', 'label',
- 'lateral', 'literal', 'literal_column', 'not_', 'null', 'nullsfirst',
- 'nullslast',
- 'or_', 'outparam', 'outerjoin', 'over', 'select', 'subquery',
- 'table', 'text',
- 'tuple_', 'type_coerce', 'quoted_name', 'union', 'union_all', 'update',
- 'within_group',
- 'TableSample', 'tablesample']
+ "Alias",
+ "any_",
+ "all_",
+ "ClauseElement",
+ "ColumnCollection",
+ "ColumnElement",
+ "CompoundSelect",
+ "Delete",
+ "FromClause",
+ "Insert",
+ "Join",
+ "Lateral",
+ "Select",
+ "Selectable",
+ "TableClause",
+ "Update",
+ "alias",
+ "and_",
+ "asc",
+ "between",
+ "bindparam",
+ "case",
+ "cast",
+ "column",
+ "delete",
+ "desc",
+ "distinct",
+ "except_",
+ "except_all",
+ "exists",
+ "extract",
+ "func",
+ "modifier",
+ "collate",
+ "insert",
+ "intersect",
+ "intersect_all",
+ "join",
+ "label",
+ "lateral",
+ "literal",
+ "literal_column",
+ "not_",
+ "null",
+ "nullsfirst",
+ "nullslast",
+ "or_",
+ "outparam",
+ "outerjoin",
+ "over",
+ "select",
+ "subquery",
+ "table",
+ "text",
+ "tuple_",
+ "type_coerce",
+ "quoted_name",
+ "union",
+ "union_all",
+ "update",
+ "within_group",
+ "TableSample",
+ "tablesample",
+]
from .visitors import Visitable
from .functions import func, modifier, FunctionElement, Function
from ..util.langhelpers import public_factory
-from .elements import ClauseElement, ColumnElement,\
- BindParameter, CollectionAggregate, UnaryExpression, BooleanClauseList, \
- Label, Cast, Case, ColumnClause, TextClause, Over, Null, \
- True_, False_, BinaryExpression, Tuple, TypeClause, Extract, \
- Grouping, WithinGroup, not_, quoted_name, \
- collate, literal_column, between,\
- literal, outparam, TypeCoerce, ClauseList, FunctionFilter
+from .elements import (
+ ClauseElement,
+ ColumnElement,
+ BindParameter,
+ CollectionAggregate,
+ UnaryExpression,
+ BooleanClauseList,
+ Label,
+ Cast,
+ Case,
+ ColumnClause,
+ TextClause,
+ Over,
+ Null,
+ True_,
+ False_,
+ BinaryExpression,
+ Tuple,
+ TypeClause,
+ Extract,
+ Grouping,
+ WithinGroup,
+ not_,
+ quoted_name,
+ collate,
+ literal_column,
+ between,
+ literal,
+ outparam,
+ TypeCoerce,
+ ClauseList,
+ FunctionFilter,
+)
-from .elements import SavepointClause, RollbackToSavepointClause, \
- ReleaseSavepointClause
+from .elements import (
+ SavepointClause,
+ RollbackToSavepointClause,
+ ReleaseSavepointClause,
+)
-from .base import ColumnCollection, Generative, Executable, \
- PARSE_AUTOCOMMIT
+from .base import ColumnCollection, Generative, Executable, PARSE_AUTOCOMMIT
-from .selectable import Alias, Join, Select, Selectable, TableClause, \
- CompoundSelect, CTE, FromClause, FromGrouping, Lateral, SelectBase, \
- alias, GenerativeSelect, subquery, HasCTE, HasPrefixes, HasSuffixes, \
- lateral, Exists, ScalarSelect, TextAsFrom, TableSample, tablesample
+from .selectable import (
+ Alias,
+ Join,
+ Select,
+ Selectable,
+ TableClause,
+ CompoundSelect,
+ CTE,
+ FromClause,
+ FromGrouping,
+ Lateral,
+ SelectBase,
+ alias,
+ GenerativeSelect,
+ subquery,
+ HasCTE,
+ HasPrefixes,
+ HasSuffixes,
+ lateral,
+ Exists,
+ ScalarSelect,
+ TextAsFrom,
+ TableSample,
+ tablesample,
+)
from .dml import Insert, Update, Delete, UpdateBase, ValuesBase
@@ -79,23 +178,30 @@ extract = public_factory(Extract, ".expression.extract")
tuple_ = public_factory(Tuple, ".expression.tuple_")
except_ = public_factory(CompoundSelect._create_except, ".expression.except_")
except_all = public_factory(
- CompoundSelect._create_except_all, ".expression.except_all")
+ CompoundSelect._create_except_all, ".expression.except_all"
+)
intersect = public_factory(
- CompoundSelect._create_intersect, ".expression.intersect")
+ CompoundSelect._create_intersect, ".expression.intersect"
+)
intersect_all = public_factory(
- CompoundSelect._create_intersect_all, ".expression.intersect_all")
+ CompoundSelect._create_intersect_all, ".expression.intersect_all"
+)
union = public_factory(CompoundSelect._create_union, ".expression.union")
union_all = public_factory(
- CompoundSelect._create_union_all, ".expression.union_all")
+ CompoundSelect._create_union_all, ".expression.union_all"
+)
exists = public_factory(Exists, ".expression.exists")
nullsfirst = public_factory(
- UnaryExpression._create_nullsfirst, ".expression.nullsfirst")
+ UnaryExpression._create_nullsfirst, ".expression.nullsfirst"
+)
nullslast = public_factory(
- UnaryExpression._create_nullslast, ".expression.nullslast")
+ UnaryExpression._create_nullslast, ".expression.nullslast"
+)
asc = public_factory(UnaryExpression._create_asc, ".expression.asc")
desc = public_factory(UnaryExpression._create_desc, ".expression.desc")
distinct = public_factory(
- UnaryExpression._create_distinct, ".expression.distinct")
+ UnaryExpression._create_distinct, ".expression.distinct"
+)
type_coerce = public_factory(TypeCoerce, ".expression.type_coerce")
true = public_factory(True_._instance, ".expression.true")
false = public_factory(False_._instance, ".expression.false")
@@ -105,19 +211,30 @@ outerjoin = public_factory(Join._create_outerjoin, ".expression.outerjoin")
insert = public_factory(Insert, ".expression.insert")
update = public_factory(Update, ".expression.update")
delete = public_factory(Delete, ".expression.delete")
-funcfilter = public_factory(
- FunctionFilter, ".expression.funcfilter")
+funcfilter = public_factory(FunctionFilter, ".expression.funcfilter")
# internal functions still being called from tests and the ORM,
# these might be better off in some other namespace
from .base import _from_objects
-from .elements import _literal_as_text, _clause_element_as_expr,\
- _is_column, _labeled, _only_column_elements, _string_or_unprintable, \
- _truncated_label, _clone, _cloned_difference, _cloned_intersection,\
- _column_as_key, _literal_as_binds, _select_iterables, \
- _corresponding_column_or_error, _literal_as_label_reference, \
- _expression_literal_as_text
+from .elements import (
+ _literal_as_text,
+ _clause_element_as_expr,
+ _is_column,
+ _labeled,
+ _only_column_elements,
+ _string_or_unprintable,
+ _truncated_label,
+ _clone,
+ _cloned_difference,
+ _cloned_intersection,
+ _column_as_key,
+ _literal_as_binds,
+ _select_iterables,
+ _corresponding_column_or_error,
+ _literal_as_label_reference,
+ _expression_literal_as_text,
+)
from .selectable import _interpret_as_from
diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py
index 4b4d2d463..883bb8cc3 100644
--- a/lib/sqlalchemy/sql/functions.py
+++ b/lib/sqlalchemy/sql/functions.py
@@ -10,10 +10,22 @@
"""
from . import sqltypes, schema
from .base import Executable, ColumnCollection
-from .elements import ClauseList, Cast, Extract, _literal_as_binds, \
- literal_column, _type_from_args, ColumnElement, _clone,\
- Over, BindParameter, FunctionFilter, Grouping, WithinGroup, \
- BinaryExpression
+from .elements import (
+ ClauseList,
+ Cast,
+ Extract,
+ _literal_as_binds,
+ literal_column,
+ _type_from_args,
+ ColumnElement,
+ _clone,
+ Over,
+ BindParameter,
+ FunctionFilter,
+ Grouping,
+ WithinGroup,
+ BinaryExpression,
+)
from .selectable import FromClause, Select, Alias
from . import util as sqlutil
from . import operators
@@ -62,9 +74,8 @@ class FunctionElement(Executable, ColumnElement, FromClause):
args = [_literal_as_binds(c, self.name) for c in clauses]
self._has_args = self._has_args or bool(args)
self.clause_expr = ClauseList(
- operator=operators.comma_op,
- group_contents=True, *args).\
- self_group()
+ operator=operators.comma_op, group_contents=True, *args
+ ).self_group()
def _execute_on_connection(self, connection, multiparams, params):
return connection._execute_function(self, multiparams, params)
@@ -123,7 +134,7 @@ class FunctionElement(Executable, ColumnElement, FromClause):
partition_by=partition_by,
order_by=order_by,
rows=rows,
- range_=range_
+ range_=range_,
)
def within_group(self, *order_by):
@@ -233,16 +244,14 @@ class FunctionElement(Executable, ColumnElement, FromClause):
.. versionadded:: 1.3
"""
- return FunctionAsBinary(
- self, left_index, right_index
- )
+ return FunctionAsBinary(self, left_index, right_index)
@property
def _from_objects(self):
return self.clauses._from_objects
def get_children(self, **kwargs):
- return self.clause_expr,
+ return (self.clause_expr,)
def _copy_internals(self, clone=_clone, **kw):
self.clause_expr = clone(self.clause_expr, **kw)
@@ -336,24 +345,29 @@ class FunctionElement(Executable, ColumnElement, FromClause):
return self.select().execute()
def _bind_param(self, operator, obj, type_=None):
- return BindParameter(None, obj, _compared_to_operator=operator,
- _compared_to_type=self.type, unique=True,
- type_=type_)
+ return BindParameter(
+ None,
+ obj,
+ _compared_to_operator=operator,
+ _compared_to_type=self.type,
+ unique=True,
+ type_=type_,
+ )
def self_group(self, against=None):
# for the moment, we are parenthesizing all array-returning
# expressions against getitem. This may need to be made
# more portable if in the future we support other DBs
# besides postgresql.
- if against is operators.getitem and \
- isinstance(self.type, sqltypes.ARRAY):
+ if against is operators.getitem and isinstance(
+ self.type, sqltypes.ARRAY
+ ):
return Grouping(self)
else:
return super(FunctionElement, self).self_group(against=against)
class FunctionAsBinary(BinaryExpression):
-
def __init__(self, fn, left_index, right_index):
left = fn.clauses.clauses[left_index - 1]
right = fn.clauses.clauses[right_index - 1]
@@ -362,8 +376,11 @@ class FunctionAsBinary(BinaryExpression):
self.right_index = right_index
super(FunctionAsBinary, self).__init__(
- left, right, operators.function_as_comparison_op,
- type_=sqltypes.BOOLEANTYPE)
+ left,
+ right,
+ operators.function_as_comparison_op,
+ type_=sqltypes.BOOLEANTYPE,
+ )
@property
def left(self):
@@ -382,7 +399,7 @@ class FunctionAsBinary(BinaryExpression):
self.sql_function.clauses.clauses[self.right_index - 1] = value
def _copy_internals(self, **kw):
- clone = kw.pop('clone')
+ clone = kw.pop("clone")
self.sql_function = clone(self.sql_function, **kw)
super(FunctionAsBinary, self)._copy_internals(**kw)
@@ -396,13 +413,13 @@ class _FunctionGenerator(object):
def __getattr__(self, name):
# passthru __ attributes; fixes pydoc
- if name.startswith('__'):
+ if name.startswith("__"):
try:
return self.__dict__[name]
except KeyError:
raise AttributeError(name)
- elif name.endswith('_'):
+ elif name.endswith("_"):
name = name[0:-1]
f = _FunctionGenerator(**self.opts)
f.__names = list(self.__names) + [name]
@@ -426,8 +443,9 @@ class _FunctionGenerator(object):
if func is not None:
return func(*c, **o)
- return Function(self.__names[-1],
- packagenames=self.__names[0:-1], *c, **o)
+ return Function(
+ self.__names[-1], packagenames=self.__names[0:-1], *c, **o
+ )
func = _FunctionGenerator()
@@ -523,7 +541,7 @@ class Function(FunctionElement):
"""
- __visit_name__ = 'function'
+ __visit_name__ = "function"
def __init__(self, name, *clauses, **kw):
"""Construct a :class:`.Function`.
@@ -532,30 +550,33 @@ class Function(FunctionElement):
new :class:`.Function` instances.
"""
- self.packagenames = kw.pop('packagenames', None) or []
+ self.packagenames = kw.pop("packagenames", None) or []
self.name = name
- self._bind = kw.get('bind', None)
- self.type = sqltypes.to_instance(kw.get('type_', None))
+ self._bind = kw.get("bind", None)
+ self.type = sqltypes.to_instance(kw.get("type_", None))
FunctionElement.__init__(self, *clauses, **kw)
def _bind_param(self, operator, obj, type_=None):
- return BindParameter(self.name, obj,
- _compared_to_operator=operator,
- _compared_to_type=self.type,
- type_=type_,
- unique=True)
+ return BindParameter(
+ self.name,
+ obj,
+ _compared_to_operator=operator,
+ _compared_to_type=self.type,
+ type_=type_,
+ unique=True,
+ )
class _GenericMeta(VisitableType):
def __init__(cls, clsname, bases, clsdict):
if annotation.Annotated not in cls.__mro__:
- cls.name = name = clsdict.get('name', clsname)
- cls.identifier = identifier = clsdict.get('identifier', name)
- package = clsdict.pop('package', '_default')
+ cls.name = name = clsdict.get("name", clsname)
+ cls.identifier = identifier = clsdict.get("identifier", name)
+ package = clsdict.pop("package", "_default")
# legacy
- if '__return_type__' in clsdict:
- cls.type = clsdict['__return_type__']
+ if "__return_type__" in clsdict:
+ cls.type = clsdict["__return_type__"]
register_function(identifier, cls, package)
super(_GenericMeta, cls).__init__(clsname, bases, clsdict)
@@ -635,17 +656,19 @@ class GenericFunction(util.with_metaclass(_GenericMeta, Function)):
coerce_arguments = True
def __init__(self, *args, **kwargs):
- parsed_args = kwargs.pop('_parsed_args', None)
+ parsed_args = kwargs.pop("_parsed_args", None)
if parsed_args is None:
parsed_args = [_literal_as_binds(c, self.name) for c in args]
self._has_args = self._has_args or bool(parsed_args)
self.packagenames = []
- self._bind = kwargs.get('bind', None)
+ self._bind = kwargs.get("bind", None)
self.clause_expr = ClauseList(
- operator=operators.comma_op,
- group_contents=True, *parsed_args).self_group()
+ operator=operators.comma_op, group_contents=True, *parsed_args
+ ).self_group()
self.type = sqltypes.to_instance(
- kwargs.pop("type_", None) or getattr(self, 'type', None))
+ kwargs.pop("type_", None) or getattr(self, "type", None)
+ )
+
register_function("cast", Cast)
register_function("extract", Extract)
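``_GenericMeta`` registers each :class:`.GenericFunction` subclass with the ``func`` namespace under its class name; a minimal sketch::

    from sqlalchemy import func
    from sqlalchemy.sql.functions import GenericFunction
    from sqlalchemy.types import DateTime

    class as_utc(GenericFunction):
        type = DateTime

    print(func.as_utc())  # renders: as_utc()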
@@ -660,13 +683,15 @@ class next_value(GenericFunction):
that does not provide support for sequences.
"""
+
type = sqltypes.Integer()
name = "next_value"
def __init__(self, seq, **kw):
- assert isinstance(seq, schema.Sequence), \
- "next_value() accepts a Sequence object as input."
- self._bind = kw.get('bind', None)
+ assert isinstance(
+ seq, schema.Sequence
+ ), "next_value() accepts a Sequence object as input."
+ self._bind = kw.get("bind", None)
self.sequence = seq
@property
@@ -684,8 +709,8 @@ class ReturnTypeFromArgs(GenericFunction):
def __init__(self, *args, **kwargs):
args = [_literal_as_binds(c, self.name) for c in args]
- kwargs.setdefault('type_', _type_from_args(args))
- kwargs['_parsed_args'] = args
+ kwargs.setdefault("type_", _type_from_args(args))
+ kwargs["_parsed_args"] = args
super(ReturnTypeFromArgs, self).__init__(*args, **kwargs)
@@ -733,7 +758,7 @@ class count(GenericFunction):
def __init__(self, expression=None, **kwargs):
if expression is None:
- expression = literal_column('*')
+ expression = literal_column("*")
super(count, self).__init__(expression, **kwargs)
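With no argument, :class:`.count` falls back to ``literal_column("*")``::

    from sqlalchemy import func
    from sqlalchemy.sql import column

    print(func.count())              # count(*)
    print(func.count(column("id")))  # count(id)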
@@ -797,15 +822,15 @@ class array_agg(GenericFunction):
def __init__(self, *args, **kwargs):
args = [_literal_as_binds(c) for c in args]
- default_array_type = kwargs.pop('_default_array_type', sqltypes.ARRAY)
- if 'type_' not in kwargs:
+ default_array_type = kwargs.pop("_default_array_type", sqltypes.ARRAY)
+ if "type_" not in kwargs:
type_from_args = _type_from_args(args)
if isinstance(type_from_args, sqltypes.ARRAY):
- kwargs['type_'] = type_from_args
+ kwargs["type_"] = type_from_args
else:
- kwargs['type_'] = default_array_type(type_from_args)
- kwargs['_parsed_args'] = args
+ kwargs["type_"] = default_array_type(type_from_args)
+ kwargs["_parsed_args"] = args
super(array_agg, self).__init__(*args, **kwargs)
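:class:`.array_agg` wraps the derived element type in ``ARRAY`` unless ``type_`` is given explicitly; a short sketch::

    from sqlalchemy import Integer, func
    from sqlalchemy.sql import column

    expr = func.array_agg(column("data", Integer))
    # expr.type is ARRAY(Integer)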
@@ -883,6 +908,7 @@ class rank(GenericFunction):
.. versionadded:: 1.1
"""
+
type = sqltypes.Integer()
@@ -897,6 +923,7 @@ class dense_rank(GenericFunction):
.. versionadded:: 1.1
"""
+
type = sqltypes.Integer()
@@ -911,6 +938,7 @@ class percent_rank(GenericFunction):
.. versionadded:: 1.1
"""
+
type = sqltypes.Numeric()
@@ -925,6 +953,7 @@ class cume_dist(GenericFunction):
.. versionadded:: 1.1
"""
+
type = sqltypes.Numeric()
diff --git a/lib/sqlalchemy/sql/naming.py b/lib/sqlalchemy/sql/naming.py
index 0107ce724..144cc4dfc 100644
--- a/lib/sqlalchemy/sql/naming.py
+++ b/lib/sqlalchemy/sql/naming.py
@@ -10,8 +10,16 @@
"""
-from .schema import Constraint, ForeignKeyConstraint, PrimaryKeyConstraint, \
- UniqueConstraint, CheckConstraint, Index, Table, Column
+from .schema import (
+ Constraint,
+ ForeignKeyConstraint,
+ PrimaryKeyConstraint,
+ UniqueConstraint,
+ CheckConstraint,
+ Index,
+ Table,
+ Column,
+)
from .. import event, events
from .. import exc
from .elements import _truncated_label, _defer_name, _defer_none_name, conv
@@ -19,7 +27,6 @@ import re
class ConventionDict(object):
-
def __init__(self, const, table, convention):
self.const = const
self._is_fk = isinstance(const, ForeignKeyConstraint)
@@ -79,8 +86,8 @@ class ConventionDict(object):
def __getitem__(self, key):
if key in self.convention:
return self.convention[key](self.const, self.table)
- elif hasattr(self, '_key_%s' % key):
- return getattr(self, '_key_%s' % key)()
+ elif hasattr(self, "_key_%s" % key):
+ return getattr(self, "_key_%s" % key)()
else:
col_template = re.match(r".*_?column_(\d+)(_?N)?_.+", key)
if col_template:
@@ -108,12 +115,13 @@ class ConventionDict(object):
return getattr(self, attr)(idx)
raise KeyError(key)
+
_prefix_dict = {
Index: "ix",
PrimaryKeyConstraint: "pk",
CheckConstraint: "ck",
UniqueConstraint: "uq",
- ForeignKeyConstraint: "fk"
+ ForeignKeyConstraint: "fk",
}
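The ``_prefix_dict`` tokens above are the keys conventionally used in a ``naming_convention``; the standard documented convention looks like::

    from sqlalchemy import MetaData

    metadata = MetaData(naming_convention={
        "ix": "ix_%(column_0_label)s",
        "uq": "uq_%(table_name)s_%(column_0_name)s",
        "ck": "ck_%(table_name)s_%(constraint_name)s",
        "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
        "pk": "pk_%(table_name)s",
    })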
@@ -134,15 +142,18 @@ def _constraint_name_for_table(const, table):
if isinstance(const.name, conv):
return const.name
- elif convention is not None and \
- not isinstance(const.name, conv) and \
- (
- const.name is None or
- "constraint_name" in convention or
- isinstance(const.name, _defer_name)):
+ elif (
+ convention is not None
+ and not isinstance(const.name, conv)
+ and (
+ const.name is None
+ or "constraint_name" in convention
+ or isinstance(const.name, _defer_name)
+ )
+ ):
return conv(
- convention % ConventionDict(const, table,
- metadata.naming_convention)
+ convention
+ % ConventionDict(const, table, metadata.naming_convention)
)
elif isinstance(convention, _defer_none_name):
return None
@@ -155,9 +166,11 @@ def _constraint_name(const, table):
# for column-attached constraint, set another event
# to link the column attached to the table as this constraint
# associated with the table.
- event.listen(table, "after_parent_attach",
- lambda col, table: _constraint_name(const, table)
- )
+ event.listen(
+ table,
+ "after_parent_attach",
+ lambda col, table: _constraint_name(const, table),
+ )
elif isinstance(table, Table):
if isinstance(const.name, (conv, _defer_name)):
return
diff --git a/lib/sqlalchemy/sql/operators.py b/lib/sqlalchemy/sql/operators.py
index 5b4a28a06..2b843d751 100644
--- a/lib/sqlalchemy/sql/operators.py
+++ b/lib/sqlalchemy/sql/operators.py
@@ -13,8 +13,25 @@
from .. import util
from operator import (
- and_, or_, inv, add, mul, sub, mod, truediv, lt, le, ne, gt, ge, eq, neg,
- getitem, lshift, rshift, contains
+ and_,
+ or_,
+ inv,
+ add,
+ mul,
+ sub,
+ mod,
+ truediv,
+ lt,
+ le,
+ ne,
+ gt,
+ ge,
+ eq,
+ neg,
+ getitem,
+ lshift,
+ rshift,
+ contains,
)
if util.py2k:
@@ -37,6 +54,7 @@ class Operators(object):
:class:`.ColumnOperators`.
"""
+
__slots__ = ()
def __and__(self, other):
@@ -105,8 +123,8 @@ class Operators(object):
return self.operate(inv)
def op(
- self, opstring, precedence=0, is_comparison=False,
- return_type=None):
+ self, opstring, precedence=0, is_comparison=False, return_type=None
+ ):
"""produce a generic operator function.
e.g.::
@@ -168,6 +186,7 @@ class Operators(object):
def against(other):
return operator(self, other)
+
return against
def bool_op(self, opstring, precedence=0):
@@ -247,12 +266,18 @@ class custom_op(object):
:meth:`.Operators.bool_op`
"""
- __name__ = 'custom_op'
+
+ __name__ = "custom_op"
def __init__(
- self, opstring, precedence=0, is_comparison=False,
- return_type=None, natural_self_precedent=False,
- eager_grouping=False):
+ self,
+ opstring,
+ precedence=0,
+ is_comparison=False,
+ return_type=None,
+ natural_self_precedent=False,
+ eager_grouping=False,
+ ):
self.opstring = opstring
self.precedence = precedence
self.is_comparison = is_comparison
@@ -263,8 +288,7 @@ class custom_op(object):
)
def __eq__(self, other):
- return isinstance(other, custom_op) and \
- other.opstring == self.opstring
+ return isinstance(other, custom_op) and other.opstring == self.opstring
def __hash__(self):
return id(self)
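A short sketch of :meth:`.Operators.op` and :meth:`.Operators.bool_op`, which construct :class:`.custom_op` under the hood (the ``&&&`` operator is illustrative only)::

    from sqlalchemy.sql import column

    expr = column("x").op("&&&", precedence=5, is_comparison=True)(column("y"))
    bool_expr = column("x").bool_op("&&&")(column("y"))  # Boolean-typed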
@@ -1138,6 +1162,7 @@ class ColumnOperators(Operators):
"""
return self.reverse_operate(truediv, other)
+
_commutative = {eq, ne, add, mul}
_comparison = {eq, ne, lt, gt, ge, le}
@@ -1261,20 +1286,18 @@ def _escaped_like_impl(fn, other, escape, autoescape):
if autoescape:
if autoescape is not True:
util.warn(
- "The autoescape parameter is now a simple boolean True/False")
+ "The autoescape parameter is now a simple boolean True/False"
+ )
if escape is None:
- escape = '/'
+ escape = "/"
if not isinstance(other, util.compat.string_types):
raise TypeError("String value expected when autoescape=True")
- if escape not in ('%', '_'):
+ if escape not in ("%", "_"):
other = other.replace(escape, escape + escape)
- other = (
- other.replace('%', escape + '%').
- replace('_', escape + '_')
- )
+ other = other.replace("%", escape + "%").replace("_", escape + "_")
return fn(other, escape=escape)
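With ``autoescape=True``, LIKE wildcards in the given value are escaped with ``/`` before rendering; a short sketch using :meth:`.ColumnOperators.contains`::

    from sqlalchemy.sql import column

    name = column("name")

    # "5%-off" is escaped to "5/%-off" inside LIKE '%...%' ESCAPE '/'
    expr = name.contains("5%-off", autoescape=True)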
@@ -1362,8 +1385,7 @@ def json_path_getitem_op(a, b):
def is_comparison(op):
- return op in _comparison or \
- isinstance(op, custom_op) and op.is_comparison
+ return op in _comparison or isinstance(op, custom_op) and op.is_comparison
def is_commutative(op):
@@ -1371,13 +1393,16 @@ def is_commutative(op):
def is_ordering_modifier(op):
- return op in (asc_op, desc_op,
- nullsfirst_op, nullslast_op)
+ return op in (asc_op, desc_op, nullsfirst_op, nullslast_op)
def is_natural_self_precedent(op):
- return op in _natural_self_precedent or \
- isinstance(op, custom_op) and op.natural_self_precedent
+ return (
+ op in _natural_self_precedent
+ or isinstance(op, custom_op)
+ and op.natural_self_precedent
+ )
+
_booleans = (inv, istrue, isfalse, and_, or_)
@@ -1385,12 +1410,8 @@ _booleans = (inv, istrue, isfalse, and_, or_)
def is_boolean(op):
return is_comparison(op) or op in _booleans
-_mirror = {
- gt: lt,
- ge: le,
- lt: gt,
- le: ge
-}
+
+_mirror = {gt: lt, ge: le, lt: gt, le: ge}
def mirror(op):
@@ -1404,17 +1425,18 @@ def mirror(op):
_associative = _commutative.union([concat_op, and_, or_]).difference([eq, ne])
-_natural_self_precedent = _associative.union([
- getitem, json_getitem_op, json_path_getitem_op])
+_natural_self_precedent = _associative.union(
+ [getitem, json_getitem_op, json_path_getitem_op]
+)
"""Operators where if we have (a op b) op c, we don't want to
parenthesize (a op b).
"""
-_asbool = util.symbol('_asbool', canonical=-10)
-_smallest = util.symbol('_smallest', canonical=-100)
-_largest = util.symbol('_largest', canonical=100)
+_asbool = util.symbol("_asbool", canonical=-10)
+_smallest = util.symbol("_smallest", canonical=-100)
+_largest = util.symbol("_largest", canonical=100)
_PRECEDENCE = {
from_: 15,
@@ -1424,7 +1446,6 @@ _PRECEDENCE = {
getitem: 15,
json_getitem_op: 15,
json_path_getitem_op: 15,
-
mul: 8,
truediv: 8,
div: 8,
@@ -1432,22 +1453,17 @@ _PRECEDENCE = {
neg: 8,
add: 7,
sub: 7,
-
concat_op: 6,
-
match_op: 5,
notmatch_op: 5,
-
ilike_op: 5,
notilike_op: 5,
like_op: 5,
notlike_op: 5,
in_op: 5,
notin_op: 5,
-
is_: 5,
isnot: 5,
-
eq: 5,
ne: 5,
is_distinct_from: 5,
@@ -1458,7 +1474,6 @@ _PRECEDENCE = {
lt: 5,
ge: 5,
le: 5,
-
between_op: 5,
notbetween_op: 5,
distinct_op: 5,
@@ -1468,17 +1483,14 @@ _PRECEDENCE = {
and_: 3,
or_: 2,
comma_op: -1,
-
desc_op: 3,
asc_op: 3,
collate: 4,
-
as_: -1,
exists: 0,
-
_asbool: -10,
_smallest: _smallest,
- _largest: _largest
+ _largest: _largest,
}
@@ -1486,7 +1498,6 @@ def is_precedent(operator, against):
if operator is against and is_natural_self_precedent(operator):
return False
else:
- return (_PRECEDENCE.get(operator,
- getattr(operator, 'precedence', _smallest)) <=
- _PRECEDENCE.get(against,
- getattr(against, 'precedence', _largest)))
+ return _PRECEDENCE.get(
+ operator, getattr(operator, "precedence", _smallest)
+ ) <= _PRECEDENCE.get(against, getattr(against, "precedence", _largest))
diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py
index 3e9aa174a..d6c3f5000 100644
--- a/lib/sqlalchemy/sql/schema.py
+++ b/lib/sqlalchemy/sql/schema.py
@@ -36,25 +36,31 @@ import operator
from . import visitors
from . import type_api
from .base import _bind_or_error, ColumnCollection
-from .elements import ClauseElement, ColumnClause, \
- _as_truncated, TextClause, _literal_as_text,\
- ColumnElement, quoted_name
+from .elements import (
+ ClauseElement,
+ ColumnClause,
+ _as_truncated,
+ TextClause,
+ _literal_as_text,
+ ColumnElement,
+ quoted_name,
+)
from .selectable import TableClause
import collections
import sqlalchemy
from . import ddl
-RETAIN_SCHEMA = util.symbol('retain_schema')
+RETAIN_SCHEMA = util.symbol("retain_schema")
BLANK_SCHEMA = util.symbol(
- 'blank_schema',
+ "blank_schema",
"""Symbol indicating that a :class:`.Table` or :class:`.Sequence`
should have 'None' for its schema, even if the parent
:class:`.MetaData` has specified a schema.
.. versionadded:: 1.0.14
- """
+ """,
)
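A minimal sketch of ``BLANK_SCHEMA``, opting a single :class:`.Table` out of a :class:`.MetaData`-level schema::

    from sqlalchemy import Column, Integer, MetaData, Table
    from sqlalchemy.schema import BLANK_SCHEMA

    metadata = MetaData(schema="remote")

    # this table keeps schema None despite the MetaData default
    local = Table(
        "local", metadata,
        Column("id", Integer, primary_key=True),
        schema=BLANK_SCHEMA,
    )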
@@ -69,11 +75,15 @@ def _get_table_key(name, schema):
# break an import cycle
def _copy_expression(expression, source_table, target_table):
def replace(col):
- if isinstance(col, Column) and \
- col.table is source_table and col.key in source_table.c:
+ if (
+ isinstance(col, Column)
+ and col.table is source_table
+ and col.key in source_table.c
+ ):
return target_table.c[col.key]
else:
return None
+
return visitors.replacement_traverse(expression, {}, replace)
@@ -81,7 +91,7 @@ def _copy_expression(expression, source_table, target_table):
class SchemaItem(SchemaEventTarget, visitors.Visitable):
"""Base class for items that define a database schema."""
- __visit_name__ = 'schema_item'
+ __visit_name__ = "schema_item"
def _init_items(self, *args):
"""Initialize the list of child items for this SchemaItem."""
@@ -95,10 +105,10 @@ class SchemaItem(SchemaEventTarget, visitors.Visitable):
return []
def __repr__(self):
- return util.generic_repr(self, omit_kwarg=['info'])
+ return util.generic_repr(self, omit_kwarg=["info"])
@property
- @util.deprecated('0.9', 'Use ``<obj>.name.quote``')
+ @util.deprecated("0.9", "Use ``<obj>.name.quote``")
def quote(self):
"""Return the value of the ``quote`` flag passed
to this schema object, for those schema items which
@@ -121,7 +131,7 @@ class SchemaItem(SchemaEventTarget, visitors.Visitable):
return {}
def _schema_item_copy(self, schema_item):
- if 'info' in self.__dict__:
+ if "info" in self.__dict__:
schema_item.info = self.info.copy()
schema_item.dispatch._update(self.dispatch)
return schema_item
@@ -396,7 +406,7 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
"""
- __visit_name__ = 'table'
+ __visit_name__ = "table"
def __new__(cls, *args, **kw):
if not args:
@@ -408,26 +418,26 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
except IndexError:
raise TypeError("Table() takes at least two arguments")
- schema = kw.get('schema', None)
+ schema = kw.get("schema", None)
if schema is None:
schema = metadata.schema
elif schema is BLANK_SCHEMA:
schema = None
- keep_existing = kw.pop('keep_existing', False)
- extend_existing = kw.pop('extend_existing', False)
- if 'useexisting' in kw:
+ keep_existing = kw.pop("keep_existing", False)
+ extend_existing = kw.pop("extend_existing", False)
+ if "useexisting" in kw:
msg = "useexisting is deprecated. Use extend_existing."
util.warn_deprecated(msg)
if extend_existing:
msg = "useexisting is synonymous with extend_existing."
raise exc.ArgumentError(msg)
- extend_existing = kw.pop('useexisting', False)
+ extend_existing = kw.pop("useexisting", False)
if keep_existing and extend_existing:
msg = "keep_existing and extend_existing are mutually exclusive."
raise exc.ArgumentError(msg)
- mustexist = kw.pop('mustexist', False)
+ mustexist = kw.pop("mustexist", False)
key = _get_table_key(name, schema)
if key in metadata.tables:
if not keep_existing and not extend_existing and bool(args):
@@ -436,15 +446,15 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
"instance. Specify 'extend_existing=True' "
"to redefine "
"options and columns on an "
- "existing Table object." % key)
+ "existing Table object." % key
+ )
table = metadata.tables[key]
if extend_existing:
table._init_existing(*args, **kw)
return table
else:
if mustexist:
- raise exc.InvalidRequestError(
- "Table '%s' not defined" % (key))
+ raise exc.InvalidRequestError("Table '%s' not defined" % (key))
table = object.__new__(cls)
table.dispatch.before_parent_attach(table, metadata)
metadata._add_table(name, schema, table)
@@ -457,7 +467,7 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
metadata._remove_table(name, schema)
@property
- @util.deprecated('0.9', 'Use ``table.schema.quote``')
+ @util.deprecated("0.9", "Use ``table.schema.quote``")
def quote_schema(self):
"""Return the value of the ``quote_schema`` flag passed
to this :class:`.Table`.
@@ -478,23 +488,25 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
def _init(self, name, metadata, *args, **kwargs):
super(Table, self).__init__(
- quoted_name(name, kwargs.pop('quote', None)))
+ quoted_name(name, kwargs.pop("quote", None))
+ )
self.metadata = metadata
- self.schema = kwargs.pop('schema', None)
+ self.schema = kwargs.pop("schema", None)
if self.schema is None:
self.schema = metadata.schema
elif self.schema is BLANK_SCHEMA:
self.schema = None
else:
- quote_schema = kwargs.pop('quote_schema', None)
+ quote_schema = kwargs.pop("quote_schema", None)
self.schema = quoted_name(self.schema, quote_schema)
self.indexes = set()
self.constraints = set()
self._columns = ColumnCollection()
- PrimaryKeyConstraint(_implicit_generated=True).\
- _set_parent_with_dispatch(self)
+ PrimaryKeyConstraint(
+ _implicit_generated=True
+ )._set_parent_with_dispatch(self)
self.foreign_keys = set()
self._extra_dependencies = set()
if self.schema is not None:
@@ -502,26 +514,26 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
else:
self.fullname = self.name
- autoload_with = kwargs.pop('autoload_with', None)
- autoload = kwargs.pop('autoload', autoload_with is not None)
+ autoload_with = kwargs.pop("autoload_with", None)
+ autoload = kwargs.pop("autoload", autoload_with is not None)
# this argument is only used with _init_existing()
- kwargs.pop('autoload_replace', True)
+ kwargs.pop("autoload_replace", True)
_extend_on = kwargs.pop("_extend_on", None)
- include_columns = kwargs.pop('include_columns', None)
+ include_columns = kwargs.pop("include_columns", None)
- self.implicit_returning = kwargs.pop('implicit_returning', True)
+ self.implicit_returning = kwargs.pop("implicit_returning", True)
- self.comment = kwargs.pop('comment', None)
+ self.comment = kwargs.pop("comment", None)
- if 'info' in kwargs:
- self.info = kwargs.pop('info')
- if 'listeners' in kwargs:
- listeners = kwargs.pop('listeners')
+ if "info" in kwargs:
+ self.info = kwargs.pop("info")
+ if "listeners" in kwargs:
+ listeners = kwargs.pop("listeners")
for evt, fn in listeners:
event.listen(self, evt, fn)
- self._prefixes = kwargs.pop('prefixes', [])
+ self._prefixes = kwargs.pop("prefixes", [])
self._extra_kwargs(**kwargs)
@@ -530,21 +542,29 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
# circular foreign keys
if autoload:
self._autoload(
- metadata, autoload_with,
- include_columns, _extend_on=_extend_on)
+ metadata, autoload_with, include_columns, _extend_on=_extend_on
+ )
# initialize all the column, etc. objects. done after reflection to
# allow user-overrides
self._init_items(*args)
- def _autoload(self, metadata, autoload_with, include_columns,
- exclude_columns=(), _extend_on=None):
+ def _autoload(
+ self,
+ metadata,
+ autoload_with,
+ include_columns,
+ exclude_columns=(),
+ _extend_on=None,
+ ):
if autoload_with:
autoload_with.run_callable(
autoload_with.dialect.reflecttable,
- self, include_columns, exclude_columns,
- _extend_on=_extend_on
+ self,
+ include_columns,
+ exclude_columns,
+ _extend_on=_extend_on,
)
else:
bind = _bind_or_error(
@@ -553,11 +573,14 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
"Pass an engine to the Table via "
"autoload_with=<someengine>, "
"or associate the MetaData with an engine via "
- "metadata.bind=<someengine>")
+ "metadata.bind=<someengine>",
+ )
bind.run_callable(
bind.dialect.reflecttable,
- self, include_columns, exclude_columns,
- _extend_on=_extend_on
+ self,
+ include_columns,
+ exclude_columns,
+ _extend_on=_extend_on,
)
@property
@@ -582,34 +605,36 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
return set(fkc.constraint for fkc in self.foreign_keys)
def _init_existing(self, *args, **kwargs):
- autoload_with = kwargs.pop('autoload_with', None)
- autoload = kwargs.pop('autoload', autoload_with is not None)
- autoload_replace = kwargs.pop('autoload_replace', True)
- schema = kwargs.pop('schema', None)
- _extend_on = kwargs.pop('_extend_on', None)
+ autoload_with = kwargs.pop("autoload_with", None)
+ autoload = kwargs.pop("autoload", autoload_with is not None)
+ autoload_replace = kwargs.pop("autoload_replace", True)
+ schema = kwargs.pop("schema", None)
+ _extend_on = kwargs.pop("_extend_on", None)
if schema and schema != self.schema:
raise exc.ArgumentError(
"Can't change schema of existing table from '%s' to '%s'",
- (self.schema, schema))
+ (self.schema, schema),
+ )
- include_columns = kwargs.pop('include_columns', None)
+ include_columns = kwargs.pop("include_columns", None)
if include_columns is not None:
for c in self.c:
if c.name not in include_columns:
self._columns.remove(c)
- for key in ('quote', 'quote_schema'):
+ for key in ("quote", "quote_schema"):
if key in kwargs:
raise exc.ArgumentError(
- "Can't redefine 'quote' or 'quote_schema' arguments")
+ "Can't redefine 'quote' or 'quote_schema' arguments"
+ )
- if 'comment' in kwargs:
- self.comment = kwargs.pop('comment', None)
+ if "comment" in kwargs:
+ self.comment = kwargs.pop("comment", None)
- if 'info' in kwargs:
- self.info = kwargs.pop('info')
+ if "info" in kwargs:
+ self.info = kwargs.pop("info")
if autoload:
if not autoload_replace:
@@ -620,8 +645,12 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
else:
exclude_columns = ()
self._autoload(
- self.metadata, autoload_with,
- include_columns, exclude_columns, _extend_on=_extend_on)
+ self.metadata,
+ autoload_with,
+ include_columns,
+ exclude_columns,
+ _extend_on=_extend_on,
+ )
self._extra_kwargs(**kwargs)
self._init_items(*args)
@@ -653,10 +682,12 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
return _get_table_key(self.name, self.schema)
def __repr__(self):
- return "Table(%s)" % ', '.join(
- [repr(self.name)] + [repr(self.metadata)] +
- [repr(x) for x in self.columns] +
- ["%s=%s" % (k, repr(getattr(self, k))) for k in ['schema']])
+ return "Table(%s)" % ", ".join(
+ [repr(self.name)]
+ + [repr(self.metadata)]
+ + [repr(x) for x in self.columns]
+ + ["%s=%s" % (k, repr(getattr(self, k))) for k in ["schema"]]
+ )
def __str__(self):
return _get_table_key(self.description, self.schema)
@@ -735,17 +766,19 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
def adapt_listener(target, connection, **kw):
listener(event_name, target, connection)
- event.listen(self, "" + event_name.replace('-', '_'), adapt_listener)
+ event.listen(self, "" + event_name.replace("-", "_"), adapt_listener)
def _set_parent(self, metadata):
metadata._add_table(self.name, self.schema, self)
self.metadata = metadata
- def get_children(self, column_collections=True,
- schema_visitor=False, **kw):
+ def get_children(
+ self, column_collections=True, schema_visitor=False, **kw
+ ):
if not schema_visitor:
return TableClause.get_children(
- self, column_collections=column_collections, **kw)
+ self, column_collections=column_collections, **kw
+ )
else:
if column_collections:
return list(self.columns)
@@ -758,8 +791,9 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
if bind is None:
bind = _bind_or_error(self)
- return bind.run_callable(bind.dialect.has_table,
- self.name, schema=self.schema)
+ return bind.run_callable(
+ bind.dialect.has_table, self.name, schema=self.schema
+ )
def create(self, bind=None, checkfirst=False):
"""Issue a ``CREATE`` statement for this
@@ -774,9 +808,7 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
if bind is None:
bind = _bind_or_error(self)
- bind._run_visitor(ddl.SchemaGenerator,
- self,
- checkfirst=checkfirst)
+ bind._run_visitor(ddl.SchemaGenerator, self, checkfirst=checkfirst)
def drop(self, bind=None, checkfirst=False):
"""Issue a ``DROP`` statement for this
@@ -790,12 +822,15 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
"""
if bind is None:
bind = _bind_or_error(self)
- bind._run_visitor(ddl.SchemaDropper,
- self,
- checkfirst=checkfirst)
-
- def tometadata(self, metadata, schema=RETAIN_SCHEMA,
- referred_schema_fn=None, name=None):
+ bind._run_visitor(ddl.SchemaDropper, self, checkfirst=checkfirst)
+
+ def tometadata(
+ self,
+ metadata,
+ schema=RETAIN_SCHEMA,
+ referred_schema_fn=None,
+ name=None,
+ ):
"""Return a copy of this :class:`.Table` associated with a different
:class:`.MetaData`.
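A short sketch of :meth:`.Table.tometadata` (names are illustrative)::

    from sqlalchemy import Column, Integer, MetaData, Table

    m1 = MetaData()
    user = Table("user", m1, Column("id", Integer, primary_key=True))

    m2 = MetaData()
    user_copy = user.tometadata(m2)                       # same name/schema
    user_archive = user.tometadata(m2, schema="archive")  # override schema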
@@ -868,29 +903,37 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
schema = metadata.schema
key = _get_table_key(name, schema)
if key in metadata.tables:
- util.warn("Table '%s' already exists within the given "
- "MetaData - not copying." % self.description)
+ util.warn(
+ "Table '%s' already exists within the given "
+ "MetaData - not copying." % self.description
+ )
return metadata.tables[key]
args = []
for c in self.columns:
args.append(c.copy(schema=schema))
table = Table(
- name, metadata, schema=schema,
+ name,
+ metadata,
+ schema=schema,
comment=self.comment,
- *args, **self.kwargs
+ *args,
+ **self.kwargs
)
for c in self.constraints:
if isinstance(c, ForeignKeyConstraint):
referred_schema = c._referred_schema
if referred_schema_fn:
fk_constraint_schema = referred_schema_fn(
- self, schema, c, referred_schema)
+ self, schema, c, referred_schema
+ )
else:
fk_constraint_schema = (
- schema if referred_schema == self.schema else None)
+ schema if referred_schema == self.schema else None
+ )
table.append_constraint(
- c.copy(schema=fk_constraint_schema, target_table=table))
+ c.copy(schema=fk_constraint_schema, target_table=table)
+ )
elif not c._type_bound:
# skip unique constraints that would be generated
# by the 'unique' flag on Column
@@ -898,25 +941,30 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
continue
table.append_constraint(
- c.copy(schema=schema, target_table=table))
+ c.copy(schema=schema, target_table=table)
+ )
for index in self.indexes:
# skip indexes that would be generated
# by the 'index' flag on Column
if index._column_flag:
continue
- Index(index.name,
- unique=index.unique,
- *[_copy_expression(expr, self, table)
- for expr in index.expressions],
- _table=table,
- **index.kwargs)
+ Index(
+ index.name,
+ unique=index.unique,
+ *[
+ _copy_expression(expr, self, table)
+ for expr in index.expressions
+ ],
+ _table=table,
+ **index.kwargs
+ )
return self._schema_item_copy(table)
class Column(DialectKWArgs, SchemaItem, ColumnClause):
"""Represents a column in a database table."""
- __visit_name__ = 'column'
+ __visit_name__ = "column"
def __init__(self, *args, **kwargs):
r"""
@@ -1192,14 +1240,15 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause):
"""
- name = kwargs.pop('name', None)
- type_ = kwargs.pop('type_', None)
+ name = kwargs.pop("name", None)
+ type_ = kwargs.pop("type_", None)
args = list(args)
if args:
if isinstance(args[0], util.string_types):
if name is not None:
raise exc.ArgumentError(
- "May not pass name positionally and as a keyword.")
+ "May not pass name positionally and as a keyword."
+ )
name = args.pop(0)
if args:
coltype = args[0]
@@ -1207,40 +1256,42 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause):
if hasattr(coltype, "_sqla_type"):
if type_ is not None:
raise exc.ArgumentError(
- "May not pass type_ positionally and as a keyword.")
+ "May not pass type_ positionally and as a keyword."
+ )
type_ = args.pop(0)
if name is not None:
- name = quoted_name(name, kwargs.pop('quote', None))
+ name = quoted_name(name, kwargs.pop("quote", None))
elif "quote" in kwargs:
- raise exc.ArgumentError("Explicit 'name' is required when "
- "sending 'quote' argument")
+ raise exc.ArgumentError(
+ "Explicit 'name' is required when " "sending 'quote' argument"
+ )
super(Column, self).__init__(name, type_)
- self.key = kwargs.pop('key', name)
- self.primary_key = kwargs.pop('primary_key', False)
- self.nullable = kwargs.pop('nullable', not self.primary_key)
- self.default = kwargs.pop('default', None)
- self.server_default = kwargs.pop('server_default', None)
- self.server_onupdate = kwargs.pop('server_onupdate', None)
+ self.key = kwargs.pop("key", name)
+ self.primary_key = kwargs.pop("primary_key", False)
+ self.nullable = kwargs.pop("nullable", not self.primary_key)
+ self.default = kwargs.pop("default", None)
+ self.server_default = kwargs.pop("server_default", None)
+ self.server_onupdate = kwargs.pop("server_onupdate", None)
# these default to None because .index and .unique is *not*
# an informational flag about Column - there can still be an
# Index or UniqueConstraint referring to this Column.
- self.index = kwargs.pop('index', None)
- self.unique = kwargs.pop('unique', None)
+ self.index = kwargs.pop("index", None)
+ self.unique = kwargs.pop("unique", None)
- self.system = kwargs.pop('system', False)
- self.doc = kwargs.pop('doc', None)
- self.onupdate = kwargs.pop('onupdate', None)
- self.autoincrement = kwargs.pop('autoincrement', "auto")
+ self.system = kwargs.pop("system", False)
+ self.doc = kwargs.pop("doc", None)
+ self.onupdate = kwargs.pop("onupdate", None)
+ self.autoincrement = kwargs.pop("autoincrement", "auto")
self.constraints = set()
self.foreign_keys = set()
- self.comment = kwargs.pop('comment', None)
+ self.comment = kwargs.pop("comment", None)
# check if this Column is proxying another column
- if '_proxies' in kwargs:
- self._proxies = kwargs.pop('_proxies')
+ if "_proxies" in kwargs:
+ self._proxies = kwargs.pop("_proxies")
# otherwise, add DDL-related events
elif isinstance(self.type, SchemaEventTarget):
self.type._set_parent_with_dispatch(self)
@@ -1249,14 +1300,13 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause):
if isinstance(self.default, (ColumnDefault, Sequence)):
args.append(self.default)
else:
- if getattr(self.type, '_warn_on_bytestring', False):
+ if getattr(self.type, "_warn_on_bytestring", False):
if isinstance(self.default, util.binary_type):
util.warn(
"Unicode column '%s' has non-unicode "
- "default value %r specified." % (
- self.key,
- self.default
- ))
+ "default value %r specified."
+ % (self.key, self.default)
+ )
args.append(ColumnDefault(self.default))
if self.server_default is not None:
@@ -1275,30 +1325,31 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause):
if isinstance(self.server_onupdate, FetchedValue):
args.append(self.server_onupdate._as_for_update(True))
else:
- args.append(DefaultClause(self.server_onupdate,
- for_update=True))
+ args.append(
+ DefaultClause(self.server_onupdate, for_update=True)
+ )
self._init_items(*args)
util.set_creation_order(self)
- if 'info' in kwargs:
- self.info = kwargs.pop('info')
+ if "info" in kwargs:
+ self.info = kwargs.pop("info")
self._extra_kwargs(**kwargs)
def _extra_kwargs(self, **kwargs):
self._validate_dialect_kwargs(kwargs)
-# @property
-# def quote(self):
-# return getattr(self.name, "quote", None)
+ # @property
+ # def quote(self):
+ # return getattr(self.name, "quote", None)
def __str__(self):
if self.name is None:
return "(no name)"
elif self.table is not None:
if self.table.named_with_column:
- return (self.table.description + "." + self.description)
+ return self.table.description + "." + self.description
else:
return self.description
else:
@@ -1320,40 +1371,47 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause):
def __repr__(self):
kwarg = []
if self.key != self.name:
- kwarg.append('key')
+ kwarg.append("key")
if self.primary_key:
- kwarg.append('primary_key')
+ kwarg.append("primary_key")
if not self.nullable:
- kwarg.append('nullable')
+ kwarg.append("nullable")
if self.onupdate:
- kwarg.append('onupdate')
+ kwarg.append("onupdate")
if self.default:
- kwarg.append('default')
+ kwarg.append("default")
if self.server_default:
- kwarg.append('server_default')
- return "Column(%s)" % ', '.join(
- [repr(self.name)] + [repr(self.type)] +
- [repr(x) for x in self.foreign_keys if x is not None] +
- [repr(x) for x in self.constraints] +
- [(self.table is not None and "table=<%s>" %
- self.table.description or "table=None")] +
- ["%s=%s" % (k, repr(getattr(self, k))) for k in kwarg])
+ kwarg.append("server_default")
+ return "Column(%s)" % ", ".join(
+ [repr(self.name)]
+ + [repr(self.type)]
+ + [repr(x) for x in self.foreign_keys if x is not None]
+ + [repr(x) for x in self.constraints]
+ + [
+ (
+ self.table is not None
+ and "table=<%s>" % self.table.description
+ or "table=None"
+ )
+ ]
+ + ["%s=%s" % (k, repr(getattr(self, k))) for k in kwarg]
+ )
def _set_parent(self, table):
if not self.name:
raise exc.ArgumentError(
"Column must be constructed with a non-blank name or "
- "assign a non-blank .name before adding to a Table.")
+ "assign a non-blank .name before adding to a Table."
+ )
if self.key is None:
self.key = self.name
- existing = getattr(self, 'table', None)
+ existing = getattr(self, "table", None)
if existing is not None and existing is not table:
raise exc.ArgumentError(
- "Column object '%s' already assigned to Table '%s'" % (
- self.key,
- existing.description
- ))
+ "Column object '%s' already assigned to Table '%s'"
+ % (self.key, existing.description)
+ )
if self.key in table._columns:
col = table._columns.get(self.key)
@@ -1373,8 +1431,9 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause):
elif self.key in table.primary_key:
raise exc.ArgumentError(
"Trying to redefine primary-key column '%s' as a "
- "non-primary-key column on table '%s'" % (
- self.key, table.fullname))
+ "non-primary-key column on table '%s'"
+ % (self.key, table.fullname)
+ )
self.table = table
@@ -1383,7 +1442,8 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause):
raise exc.ArgumentError(
"The 'index' keyword argument on Column is boolean only. "
"To create indexes with a specific name, create an "
- "explicit Index object external to the Table.")
+ "explicit Index object external to the Table."
+ )
Index(None, self, unique=bool(self.unique), _column_flag=True)
elif self.unique:
if isinstance(self.unique, util.string_types):
@@ -1392,9 +1452,11 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause):
"only. To create unique constraints or indexes with a "
"specific name, append an explicit UniqueConstraint to "
"the Table's list of elements, or create an explicit "
- "Index object external to the Table.")
+ "Index object external to the Table."
+ )
table.append_constraint(
- UniqueConstraint(self.key, _column_flag=True))
+ UniqueConstraint(self.key, _column_flag=True)
+ )
self._setup_on_memoized_fks(lambda fk: fk._set_remote_table(table))
@@ -1413,7 +1475,7 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause):
if self.table is not None:
fn(self, self.table)
else:
- event.listen(self, 'after_parent_attach', fn)
+ event.listen(self, "after_parent_attach", fn)
def copy(self, **kw):
"""Create a copy of this ``Column``, unitialized.
@@ -1423,9 +1485,9 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause):
"""
# Constraint objects plus non-constraint-bound ForeignKey objects
- args = \
- [c.copy(**kw) for c in self.constraints if not c._type_bound] + \
- [c.copy(**kw) for c in self.foreign_keys if not c.constraint]
+ args = [
+ c.copy(**kw) for c in self.constraints if not c._type_bound
+ ] + [c.copy(**kw) for c in self.foreign_keys if not c.constraint]
type_ = self.type
if isinstance(type_, SchemaEventTarget):
@@ -1452,8 +1514,9 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause):
)
return self._schema_item_copy(c)
- def _make_proxy(self, selectable, name=None, key=None,
- name_is_truncatable=False, **kw):
+ def _make_proxy(
+ self, selectable, name=None, key=None, name_is_truncatable=False, **kw
+ ):
"""Create a *proxy* for this column.
This is a copy of this ``Column`` referenced by a different parent
@@ -1462,22 +1525,28 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause):
information is not transferred.
"""
- fk = [ForeignKey(f.column, _constraint=f.constraint)
- for f in self.foreign_keys]
+ fk = [
+ ForeignKey(f.column, _constraint=f.constraint)
+ for f in self.foreign_keys
+ ]
if name is None and self.name is None:
raise exc.InvalidRequestError(
"Cannot initialize a sub-selectable"
" with this Column object until its 'name' has "
- "been assigned.")
+ "been assigned."
+ )
try:
c = self._constructor(
- _as_truncated(name or self.name) if
- name_is_truncatable else (name or self.name),
+ _as_truncated(name or self.name)
+ if name_is_truncatable
+ else (name or self.name),
self.type,
key=key if key else name if name else self.key,
primary_key=self.primary_key,
nullable=self.nullable,
- _proxies=[self], *fk)
+ _proxies=[self],
+ *fk
+ )
except TypeError:
util.raise_from_cause(
TypeError(
@@ -1485,7 +1554,8 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause):
"Ensure the class includes a _constructor() "
"attribute or method which accepts the "
"standard Column constructor arguments, or "
- "references the Column class itself." % self.__class__)
+ "references the Column class itself." % self.__class__
+ )
)
c.table = selectable
@@ -1499,9 +1569,11 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause):
def get_children(self, schema_visitor=False, **kwargs):
if schema_visitor:
- return [x for x in (self.default, self.onupdate)
- if x is not None] + \
- list(self.foreign_keys) + list(self.constraints)
+ return (
+ [x for x in (self.default, self.onupdate) if x is not None]
+ + list(self.foreign_keys)
+ + list(self.constraints)
+ )
else:
return ColumnClause.get_children(self, **kwargs)
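A sketch of the two constructor forms handled by the name/type_ parsing above; the calls are equivalent, and mixing the positional and keyword forms raises ArgumentError:

from sqlalchemy import Column, Integer

c1 = Column("id", Integer, primary_key=True)
c2 = Column(name="id", type_=Integer, primary_key=True)
# Column("id", Integer, name="id")  # ArgumentError: name passed twice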
@@ -1543,13 +1615,23 @@ class ForeignKey(DialectKWArgs, SchemaItem):
"""
- __visit_name__ = 'foreign_key'
-
- def __init__(self, column, _constraint=None, use_alter=False, name=None,
- onupdate=None, ondelete=None, deferrable=None,
- initially=None, link_to_name=False, match=None,
- info=None,
- **dialect_kw):
+ __visit_name__ = "foreign_key"
+
+ def __init__(
+ self,
+ column,
+ _constraint=None,
+ use_alter=False,
+ name=None,
+ onupdate=None,
+ ondelete=None,
+ deferrable=None,
+ initially=None,
+ link_to_name=False,
+ match=None,
+ info=None,
+ **dialect_kw
+ ):
r"""
Construct a column-level FOREIGN KEY.
@@ -1626,7 +1708,7 @@ class ForeignKey(DialectKWArgs, SchemaItem):
if isinstance(self._colspec, util.string_types):
self._table_column = None
else:
- if hasattr(self._colspec, '__clause_element__'):
+ if hasattr(self._colspec, "__clause_element__"):
self._table_column = self._colspec.__clause_element__()
else:
self._table_column = self._colspec
@@ -1634,9 +1716,11 @@ class ForeignKey(DialectKWArgs, SchemaItem):
if not isinstance(self._table_column, ColumnClause):
raise exc.ArgumentError(
"String, Column, or Column-bound argument "
- "expected, got %r" % self._table_column)
+ "expected, got %r" % self._table_column
+ )
elif not isinstance(
- self._table_column.table, (util.NoneType, TableClause)):
+ self._table_column.table, (util.NoneType, TableClause)
+ ):
raise exc.ArgumentError(
"ForeignKey received Column not bound "
"to a Table, got: %r" % self._table_column.table
@@ -1715,7 +1799,9 @@ class ForeignKey(DialectKWArgs, SchemaItem):
return "%s.%s" % (table_name, colname)
elif self._table_column is not None:
return "%s.%s" % (
- self._table_column.table.fullname, self._table_column.key)
+ self._table_column.table.fullname,
+ self._table_column.key,
+ )
else:
return self._colspec
@@ -1756,12 +1842,12 @@ class ForeignKey(DialectKWArgs, SchemaItem):
def _column_tokens(self):
"""parse a string-based _colspec into its component parts."""
- m = self._get_colspec().split('.')
+ m = self._get_colspec().split(".")
if m is None:
raise exc.ArgumentError(
- "Invalid foreign key column specification: %s" %
- self._colspec)
- if (len(m) == 1):
+ "Invalid foreign key column specification: %s" % self._colspec
+ )
+ if len(m) == 1:
tname = m.pop()
colname = None
else:
@@ -1777,8 +1863,8 @@ class ForeignKey(DialectKWArgs, SchemaItem):
# indirectly related -- Ticket #594. This assumes that '.'
# will never appear *within* any component of the FK.
- if (len(m) > 0):
- schema = '.'.join(m)
+ if len(m) > 0:
+ schema = ".".join(m)
else:
schema = None
return schema, tname, colname
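A sketch of the string colspec forms tokenized by _column_tokens() above; per the comment in the source, "." is assumed never to appear within a single component:

from sqlalchemy import Column, ForeignKey, Integer

Column("user_id", Integer, ForeignKey("users.id"))       # table.column
Column("user_id", Integer, ForeignKey("auth.users.id"))  # schema.table.column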
@@ -1787,12 +1873,14 @@ class ForeignKey(DialectKWArgs, SchemaItem):
if self.parent is None:
raise exc.InvalidRequestError(
"this ForeignKey object does not yet have a "
- "parent Column associated with it.")
+ "parent Column associated with it."
+ )
elif self.parent.table is None:
raise exc.InvalidRequestError(
"this ForeignKey's parent column is not yet associated "
- "with a Table.")
+ "with a Table."
+ )
parenttable = self.parent.table
@@ -1817,7 +1905,7 @@ class ForeignKey(DialectKWArgs, SchemaItem):
return parenttable, tablekey, colname
def _link_to_col_by_colstring(self, parenttable, table, colname):
- if not hasattr(self.constraint, '_referred_table'):
+ if not hasattr(self.constraint, "_referred_table"):
self.constraint._referred_table = table
else:
assert self.constraint._referred_table is table
@@ -1843,9 +1931,11 @@ class ForeignKey(DialectKWArgs, SchemaItem):
raise exc.NoReferencedColumnError(
"Could not initialize target column "
"for ForeignKey '%s' on table '%s': "
- "table '%s' has no column named '%s'" %
- (self._colspec, parenttable.name, table.name, key),
- table.name, key)
+ "table '%s' has no column named '%s'"
+ % (self._colspec, parenttable.name, table.name, key),
+ table.name,
+ key,
+ )
self._set_target_column(_column)
@@ -1861,6 +1951,7 @@ class ForeignKey(DialectKWArgs, SchemaItem):
def set_type(fk):
if fk.parent.type._isnull:
fk.parent.type = column.type
+
self.parent._setup_on_memoized_fks(set_type)
self.column = column
@@ -1888,21 +1979,25 @@ class ForeignKey(DialectKWArgs, SchemaItem):
raise exc.NoReferencedTableError(
"Foreign key associated with column '%s' could not find "
"table '%s' with which to generate a "
- "foreign key to target column '%s'" %
- (self.parent, tablekey, colname),
- tablekey)
+ "foreign key to target column '%s'"
+ % (self.parent, tablekey, colname),
+ tablekey,
+ )
elif parenttable.key not in parenttable.metadata:
raise exc.InvalidRequestError(
"Table %s is no longer associated with its "
- "parent MetaData" % parenttable)
+ "parent MetaData" % parenttable
+ )
else:
raise exc.NoReferencedColumnError(
"Could not initialize target column for "
"ForeignKey '%s' on table '%s': "
- "table '%s' has no column named '%s'" % (
- self._colspec, parenttable.name, tablekey, colname),
- tablekey, colname)
- elif hasattr(self._colspec, '__clause_element__'):
+ "table '%s' has no column named '%s'"
+ % (self._colspec, parenttable.name, tablekey, colname),
+ tablekey,
+ colname,
+ )
+ elif hasattr(self._colspec, "__clause_element__"):
_column = self._colspec.__clause_element__()
return _column
else:
@@ -1912,7 +2007,8 @@ class ForeignKey(DialectKWArgs, SchemaItem):
def _set_parent(self, column):
if self.parent is not None and self.parent is not column:
raise exc.InvalidRequestError(
- "This ForeignKey already has a parent !")
+ "This ForeignKey already has a parent !"
+ )
self.parent = column
self.parent.foreign_keys.add(self)
self.parent._on_table_attach(self._set_table)
@@ -1935,9 +2031,14 @@ class ForeignKey(DialectKWArgs, SchemaItem):
# on the hosting Table when attached to the Table.
if self.constraint is None and isinstance(table, Table):
self.constraint = ForeignKeyConstraint(
- [], [], use_alter=self.use_alter, name=self.name,
- onupdate=self.onupdate, ondelete=self.ondelete,
- deferrable=self.deferrable, initially=self.initially,
+ [],
+ [],
+ use_alter=self.use_alter,
+ name=self.name,
+ onupdate=self.onupdate,
+ ondelete=self.ondelete,
+ deferrable=self.deferrable,
+ initially=self.initially,
match=self.match,
**self._unvalidated_dialect_kw
)
@@ -1953,13 +2054,12 @@ class ForeignKey(DialectKWArgs, SchemaItem):
if table_key in parenttable.metadata.tables:
table = parenttable.metadata.tables[table_key]
try:
- self._link_to_col_by_colstring(
- parenttable, table, colname)
+ self._link_to_col_by_colstring(parenttable, table, colname)
except exc.NoReferencedColumnError:
# this is OK, we'll try later
pass
parenttable.metadata._fk_memos[fk_key].append(self)
- elif hasattr(self._colspec, '__clause_element__'):
+ elif hasattr(self._colspec, "__clause_element__"):
_column = self._colspec.__clause_element__()
self._set_target_column(_column)
else:
@@ -1971,7 +2071,8 @@ class _NotAColumnExpr(object):
def _not_a_column_expr(self):
raise exc.InvalidRequestError(
"This %s cannot be used directly "
- "as a column expression." % self.__class__.__name__)
+ "as a column expression." % self.__class__.__name__
+ )
__clause_element__ = self_group = lambda self: self._not_a_column_expr()
_from_objects = property(lambda self: self._not_a_column_expr())
@@ -1980,7 +2081,7 @@ class _NotAColumnExpr(object):
class DefaultGenerator(_NotAColumnExpr, SchemaItem):
"""Base class for column *default* values."""
- __visit_name__ = 'default_generator'
+ __visit_name__ = "default_generator"
is_sequence = False
is_server_default = False
@@ -2007,7 +2108,7 @@ class DefaultGenerator(_NotAColumnExpr, SchemaItem):
@property
def bind(self):
"""Return the connectable associated with this default."""
- if getattr(self, 'column', None) is not None:
+ if getattr(self, "column", None) is not None:
return self.column.table.bind
else:
return None
@@ -2064,7 +2165,8 @@ class ColumnDefault(DefaultGenerator):
super(ColumnDefault, self).__init__(**kwargs)
if isinstance(arg, FetchedValue):
raise exc.ArgumentError(
- "ColumnDefault may not be a server-side default type.")
+ "ColumnDefault may not be a server-side default type."
+ )
if util.callable(arg):
arg = self._maybe_wrap_callable(arg)
self.arg = arg
@@ -2079,9 +2181,11 @@ class ColumnDefault(DefaultGenerator):
@util.memoized_property
def is_scalar(self):
- return not self.is_callable and \
- not self.is_clause_element and \
- not self.is_sequence
+ return (
+ not self.is_callable
+ and not self.is_clause_element
+ and not self.is_sequence
+ )
@util.memoized_property
@util.dependencies("sqlalchemy.sql.sqltypes")
@@ -2114,17 +2218,19 @@ class ColumnDefault(DefaultGenerator):
else:
raise exc.ArgumentError(
"ColumnDefault Python function takes zero or one "
- "positional arguments")
+ "positional arguments"
+ )
def _visit_name(self):
if self.for_update:
return "column_onupdate"
else:
return "column_default"
+
__visit_name__ = property(_visit_name)
def __repr__(self):
- return "ColumnDefault(%r)" % (self.arg, )
+ return "ColumnDefault(%r)" % (self.arg,)
class Sequence(DefaultGenerator):
@@ -2157,15 +2263,29 @@ class Sequence(DefaultGenerator):
"""
- __visit_name__ = 'sequence'
+ __visit_name__ = "sequence"
is_sequence = True
- def __init__(self, name, start=None, increment=None, minvalue=None,
- maxvalue=None, nominvalue=None, nomaxvalue=None, cycle=None,
- schema=None, cache=None, order=None, optional=False,
- quote=None, metadata=None, quote_schema=None,
- for_update=False):
+ def __init__(
+ self,
+ name,
+ start=None,
+ increment=None,
+ minvalue=None,
+ maxvalue=None,
+ nominvalue=None,
+ nomaxvalue=None,
+ cycle=None,
+ schema=None,
+ cache=None,
+ order=None,
+ optional=False,
+ quote=None,
+ metadata=None,
+ quote_schema=None,
+ for_update=False,
+ ):
"""Construct a :class:`.Sequence` object.
:param name: The name of the sequence.
@@ -2353,27 +2473,22 @@ class Sequence(DefaultGenerator):
if bind is None:
bind = _bind_or_error(self)
- bind._run_visitor(ddl.SchemaGenerator,
- self,
- checkfirst=checkfirst)
+ bind._run_visitor(ddl.SchemaGenerator, self, checkfirst=checkfirst)
def drop(self, bind=None, checkfirst=True):
"""Drops this sequence from the database."""
if bind is None:
bind = _bind_or_error(self)
- bind._run_visitor(ddl.SchemaDropper,
- self,
- checkfirst=checkfirst)
+ bind._run_visitor(ddl.SchemaDropper, self, checkfirst=checkfirst)
def _not_a_column_expr(self):
raise exc.InvalidRequestError(
"This %s cannot be used directly "
"as a column expression. Use func.next_value(sequence) "
"to produce a 'next value' function that's usable "
- "as a column element."
- % self.__class__.__name__)
-
+ "as a column element." % self.__class__.__name__
+ )
@inspection._self_inspects
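A sketch of the workaround named in the Sequence error message above, assuming a dialect with sequence support such as PostgreSQL:

from sqlalchemy import Sequence, func, select

seq = Sequence("order_id_seq")
stmt = select([func.next_value(seq)])  # next-value expression, usable as a column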
@@ -2396,6 +2511,7 @@ class FetchedValue(_NotAColumnExpr, SchemaEventTarget):
:ref:`triggered_columns`
"""
+
is_server_default = True
reflected = False
has_argument = False
@@ -2412,7 +2528,7 @@ class FetchedValue(_NotAColumnExpr, SchemaEventTarget):
def _clone(self, for_update):
n = self.__class__.__new__(self.__class__)
n.__dict__.update(self.__dict__)
- n.__dict__.pop('column', None)
+ n.__dict__.pop("column", None)
n.for_update = for_update
return n
@@ -2452,16 +2568,15 @@ class DefaultClause(FetchedValue):
has_argument = True
def __init__(self, arg, for_update=False, _reflected=False):
- util.assert_arg_type(arg, (util.string_types[0],
- ClauseElement,
- TextClause), 'arg')
+ util.assert_arg_type(
+ arg, (util.string_types[0], ClauseElement, TextClause), "arg"
+ )
super(DefaultClause, self).__init__(for_update)
self.arg = arg
self.reflected = _reflected
def __repr__(self):
- return "DefaultClause(%r, for_update=%r)" % \
- (self.arg, self.for_update)
+ return "DefaultClause(%r, for_update=%r)" % (self.arg, self.for_update)
class PassiveDefault(DefaultClause):
@@ -2471,10 +2586,13 @@ class PassiveDefault(DefaultClause):
:class:`.PassiveDefault` is deprecated.
Use :class:`.DefaultClause`.
"""
- @util.deprecated("0.6",
- ":class:`.PassiveDefault` is deprecated. "
- "Use :class:`.DefaultClause`.",
- False)
+
+ @util.deprecated(
+ "0.6",
+ ":class:`.PassiveDefault` is deprecated. "
+ "Use :class:`.DefaultClause`.",
+ False,
+ )
def __init__(self, *arg, **kw):
DefaultClause.__init__(self, *arg, **kw)
@@ -2482,11 +2600,18 @@ class PassiveDefault(DefaultClause):
class Constraint(DialectKWArgs, SchemaItem):
"""A table-level SQL constraint."""
- __visit_name__ = 'constraint'
-
- def __init__(self, name=None, deferrable=None, initially=None,
- _create_rule=None, info=None, _type_bound=False,
- **dialect_kw):
+ __visit_name__ = "constraint"
+
+ def __init__(
+ self,
+ name=None,
+ deferrable=None,
+ initially=None,
+ _create_rule=None,
+ info=None,
+ _type_bound=False,
+ **dialect_kw
+ ):
r"""Create a SQL constraint.
:param name:
@@ -2548,7 +2673,8 @@ class Constraint(DialectKWArgs, SchemaItem):
pass
raise exc.InvalidRequestError(
"This constraint is not bound to a table. Did you "
- "mean to call table.append_constraint(constraint) ?")
+ "mean to call table.append_constraint(constraint) ?"
+ )
def _set_parent(self, parent):
self.parent = parent
@@ -2559,7 +2685,7 @@ class Constraint(DialectKWArgs, SchemaItem):
def _to_schema_column(element):
- if hasattr(element, '__clause_element__'):
+ if hasattr(element, "__clause_element__"):
element = element.__clause_element__()
if not isinstance(element, Column):
raise exc.ArgumentError("schema.Column object expected")
@@ -2567,9 +2693,9 @@ def _to_schema_column(element):
def _to_schema_column_or_string(element):
- if hasattr(element, '__clause_element__'):
+ if hasattr(element, "__clause_element__"):
element = element.__clause_element__()
- if not isinstance(element, util.string_types + (ColumnElement, )):
+ if not isinstance(element, util.string_types + (ColumnElement,)):
msg = "Element %r is not a string name or column element"
raise exc.ArgumentError(msg % element)
return element
@@ -2588,11 +2714,12 @@ class ColumnCollectionMixin(object):
_allow_multiple_tables = False
def __init__(self, *columns, **kw):
- _autoattach = kw.pop('_autoattach', True)
- self._column_flag = kw.pop('_column_flag', False)
+ _autoattach = kw.pop("_autoattach", True)
+ self._column_flag = kw.pop("_column_flag", False)
self.columns = ColumnCollection()
- self._pending_colargs = [_to_schema_column_or_string(c)
- for c in columns]
+ self._pending_colargs = [
+ _to_schema_column_or_string(c) for c in columns
+ ]
if _autoattach and self._pending_colargs:
self._check_attach()
@@ -2601,7 +2728,7 @@ class ColumnCollectionMixin(object):
for expr in expressions:
strname = None
column = None
- if hasattr(expr, '__clause_element__'):
+ if hasattr(expr, "__clause_element__"):
expr = expr.__clause_element__()
if not isinstance(expr, (ColumnElement, TextClause)):
@@ -2609,21 +2736,16 @@ class ColumnCollectionMixin(object):
strname = expr
else:
cols = []
- visitors.traverse(expr, {}, {'column': cols.append})
+ visitors.traverse(expr, {}, {"column": cols.append})
if cols:
column = cols[0]
add_element = column if column is not None else strname
yield expr, column, strname, add_element
def _check_attach(self, evt=False):
- col_objs = [
- c for c in self._pending_colargs
- if isinstance(c, Column)
- ]
+ col_objs = [c for c in self._pending_colargs if isinstance(c, Column)]
- cols_w_table = [
- c for c in col_objs if isinstance(c.table, Table)
- ]
+ cols_w_table = [c for c in col_objs if isinstance(c.table, Table)]
cols_wo_table = set(col_objs).difference(cols_w_table)
@@ -2636,6 +2758,7 @@ class ColumnCollectionMixin(object):
# columns are specified as strings.
has_string_cols = set(self._pending_colargs).difference(col_objs)
if not has_string_cols:
+
def _col_attached(column, table):
# this isinstance() corresponds with the
# isinstance() above; only want to count Table-bound
@@ -2644,6 +2767,7 @@ class ColumnCollectionMixin(object):
cols_wo_table.discard(column)
if not cols_wo_table:
self._check_attach(evt=True)
+
self._cols_wo_table = cols_wo_table
for col in cols_wo_table:
col._on_table_attach(_col_attached)
@@ -2659,9 +2783,11 @@ class ColumnCollectionMixin(object):
others = [c for c in columns[1:] if c.table is not table]
if others:
raise exc.ArgumentError(
- "Column(s) %s are not part of table '%s'." %
- (", ".join("'%s'" % c for c in others),
- table.description)
+ "Column(s) %s are not part of table '%s'."
+ % (
+ ", ".join("'%s'" % c for c in others),
+ table.description,
+ )
)
def _set_parent(self, table):
@@ -2694,11 +2820,12 @@ class ColumnCollectionConstraint(ColumnCollectionMixin, Constraint):
arguments are propagated to the :class:`.Constraint` superclass.
"""
- _autoattach = kw.pop('_autoattach', True)
- _column_flag = kw.pop('_column_flag', False)
+ _autoattach = kw.pop("_autoattach", True)
+ _column_flag = kw.pop("_column_flag", False)
Constraint.__init__(self, **kw)
ColumnCollectionMixin.__init__(
- self, *columns, _autoattach=_autoattach, _column_flag=_column_flag)
+ self, *columns, _autoattach=_autoattach, _column_flag=_column_flag
+ )
columns = None
"""A :class:`.ColumnCollection` representing the set of columns
@@ -2714,8 +2841,12 @@ class ColumnCollectionConstraint(ColumnCollectionMixin, Constraint):
return x in self.columns
def copy(self, **kw):
- c = self.__class__(name=self.name, deferrable=self.deferrable,
- initially=self.initially, *self.columns.keys())
+ c = self.__class__(
+ name=self.name,
+ deferrable=self.deferrable,
+ initially=self.initially,
+ *self.columns.keys()
+ )
return self._schema_item_copy(c)
def contains_column(self, col):
@@ -2747,9 +2878,19 @@ class CheckConstraint(ColumnCollectionConstraint):
_allow_multiple_tables = True
- def __init__(self, sqltext, name=None, deferrable=None,
- initially=None, table=None, info=None, _create_rule=None,
- _autoattach=True, _type_bound=False, **kw):
+ def __init__(
+ self,
+ sqltext,
+ name=None,
+ deferrable=None,
+ initially=None,
+ table=None,
+ info=None,
+ _create_rule=None,
+ _autoattach=True,
+ _type_bound=False,
+ **kw
+ ):
r"""Construct a CHECK constraint.
:param sqltext:
@@ -2781,14 +2922,19 @@ class CheckConstraint(ColumnCollectionConstraint):
self.sqltext = _literal_as_text(sqltext, warn=False)
columns = []
- visitors.traverse(self.sqltext, {}, {'column': columns.append})
-
- super(CheckConstraint, self).\
- __init__(
- name=name, deferrable=deferrable,
- initially=initially, _create_rule=_create_rule, info=info,
- _type_bound=_type_bound, _autoattach=_autoattach,
- *columns, **kw)
+ visitors.traverse(self.sqltext, {}, {"column": columns.append})
+
+ super(CheckConstraint, self).__init__(
+ name=name,
+ deferrable=deferrable,
+ initially=initially,
+ _create_rule=_create_rule,
+ info=info,
+ _type_bound=_type_bound,
+ _autoattach=_autoattach,
+ *columns,
+ **kw
+ )
if table is not None:
self._set_parent_with_dispatch(table)
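A sketch of the two attachment modes the CheckConstraint constructor above supports, with illustrative names:

from sqlalchemy import CheckConstraint, Column, Integer, MetaData, Table

payments = Table(
    "payments",
    MetaData(),
    Column("amount", Integer, CheckConstraint("amount > 0")),   # column-level
    CheckConstraint("amount < 1000000", name="ck_amount_max"),  # table-level
)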
@@ -2797,22 +2943,24 @@ class CheckConstraint(ColumnCollectionConstraint):
return "check_constraint"
else:
return "column_check_constraint"
+
__visit_name__ = property(__visit_name__)
def copy(self, target_table=None, **kw):
if target_table is not None:
- sqltext = _copy_expression(
- self.sqltext, self.table, target_table)
+ sqltext = _copy_expression(self.sqltext, self.table, target_table)
else:
sqltext = self.sqltext
- c = CheckConstraint(sqltext,
- name=self.name,
- initially=self.initially,
- deferrable=self.deferrable,
- _create_rule=self._create_rule,
- table=target_table,
- _autoattach=False,
- _type_bound=self._type_bound)
+ c = CheckConstraint(
+ sqltext,
+ name=self.name,
+ initially=self.initially,
+ deferrable=self.deferrable,
+ _create_rule=self._create_rule,
+ table=target_table,
+ _autoattach=False,
+ _type_bound=self._type_bound,
+ )
return self._schema_item_copy(c)
@@ -2828,12 +2976,25 @@ class ForeignKeyConstraint(ColumnCollectionConstraint):
Examples of foreign key configuration are in :ref:`metadata_foreignkeys`.
"""
- __visit_name__ = 'foreign_key_constraint'
- def __init__(self, columns, refcolumns, name=None, onupdate=None,
- ondelete=None, deferrable=None, initially=None,
- use_alter=False, link_to_name=False, match=None,
- table=None, info=None, **dialect_kw):
+ __visit_name__ = "foreign_key_constraint"
+
+ def __init__(
+ self,
+ columns,
+ refcolumns,
+ name=None,
+ onupdate=None,
+ ondelete=None,
+ deferrable=None,
+ initially=None,
+ use_alter=False,
+ link_to_name=False,
+ match=None,
+ table=None,
+ info=None,
+ **dialect_kw
+ ):
r"""Construct a composite-capable FOREIGN KEY.
:param columns: A sequence of local column names. The named columns
@@ -2905,8 +3066,13 @@ class ForeignKeyConstraint(ColumnCollectionConstraint):
"""
Constraint.__init__(
- self, name=name, deferrable=deferrable, initially=initially,
- info=info, **dialect_kw)
+ self,
+ name=name,
+ deferrable=deferrable,
+ initially=initially,
+ info=info,
+ **dialect_kw
+ )
self.onupdate = onupdate
self.ondelete = ondelete
self.link_to_name = link_to_name
@@ -2927,7 +3093,8 @@ class ForeignKeyConstraint(ColumnCollectionConstraint):
raise exc.ArgumentError(
"ForeignKeyConstraint number "
"of constrained columns must match the number of "
- "referenced columns.")
+ "referenced columns."
+ )
# standalone ForeignKeyConstraint - create
# associated ForeignKey objects which will be applied to hosted
@@ -2946,7 +3113,8 @@ class ForeignKeyConstraint(ColumnCollectionConstraint):
deferrable=self.deferrable,
initially=self.initially,
**self.dialect_kwargs
- ) for refcol in refcolumns
+ )
+ for refcol in refcolumns
]
ColumnCollectionMixin.__init__(self, *columns)
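A sketch of the composite form this constructor supports; the constrained and referenced lists must have equal length, as enforced above:

from sqlalchemy import Column, ForeignKeyConstraint, Integer, MetaData, Table

Table(
    "invoice_items",
    MetaData(),
    Column("invoice_id", Integer),
    Column("ref_num", Integer),
    ForeignKeyConstraint(
        ["invoice_id", "ref_num"],
        ["invoices.invoice_id", "invoices.ref_num"],
    ),
)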
@@ -2978,9 +3146,7 @@ class ForeignKeyConstraint(ColumnCollectionConstraint):
@property
def _elements(self):
# legacy - provide a dictionary view of (column_key, fk)
- return util.OrderedDict(
- zip(self.column_keys, self.elements)
- )
+ return util.OrderedDict(zip(self.column_keys, self.elements))
@property
def _referred_schema(self):
@@ -3004,18 +3170,14 @@ class ForeignKeyConstraint(ColumnCollectionConstraint):
return self.elements[0].column.table
def _validate_dest_table(self, table):
- table_keys = set([elem._table_key()
- for elem in self.elements])
+ table_keys = set([elem._table_key() for elem in self.elements])
if None not in table_keys and len(table_keys) > 1:
elem0, elem1 = sorted(table_keys)[0:2]
raise exc.ArgumentError(
- 'ForeignKeyConstraint on %s(%s) refers to '
- 'multiple remote tables: %s and %s' % (
- table.fullname,
- self._col_description,
- elem0,
- elem1
- ))
+ "ForeignKeyConstraint on %s(%s) refers to "
+ "multiple remote tables: %s and %s"
+ % (table.fullname, self._col_description, elem0, elem1)
+ )
@property
def column_keys(self):
@@ -3034,8 +3196,8 @@ class ForeignKeyConstraint(ColumnCollectionConstraint):
return self.columns.keys()
else:
return [
- col.key if isinstance(col, ColumnElement)
- else str(col) for col in self._pending_colargs
+ col.key if isinstance(col, ColumnElement) else str(col)
+ for col in self._pending_colargs
]
@property
@@ -3051,11 +3213,11 @@ class ForeignKeyConstraint(ColumnCollectionConstraint):
raise exc.ArgumentError(
"Can't create ForeignKeyConstraint "
"on table '%s': no column "
- "named '%s' is present." % (table.description, ke.args[0]))
+ "named '%s' is present." % (table.description, ke.args[0])
+ )
for col, fk in zip(self.columns, self.elements):
- if not hasattr(fk, 'parent') or \
- fk.parent is not col:
+ if not hasattr(fk, "parent") or fk.parent is not col:
fk._set_parent_with_dispatch(col)
self._validate_dest_table(table)
@@ -3063,13 +3225,16 @@ class ForeignKeyConstraint(ColumnCollectionConstraint):
def copy(self, schema=None, target_table=None, **kw):
fkc = ForeignKeyConstraint(
[x.parent.key for x in self.elements],
- [x._get_colspec(
- schema=schema,
- table_name=target_table.name
- if target_table is not None
- and x._table_key() == x.parent.table.key
- else None)
- for x in self.elements],
+ [
+ x._get_colspec(
+ schema=schema,
+ table_name=target_table.name
+ if target_table is not None
+ and x._table_key() == x.parent.table.key
+ else None,
+ )
+ for x in self.elements
+ ],
name=self.name,
onupdate=self.onupdate,
ondelete=self.ondelete,
@@ -3077,11 +3242,9 @@ class ForeignKeyConstraint(ColumnCollectionConstraint):
deferrable=self.deferrable,
initially=self.initially,
link_to_name=self.link_to_name,
- match=self.match
+ match=self.match,
)
- for self_fk, other_fk in zip(
- self.elements,
- fkc.elements):
+ for self_fk, other_fk in zip(self.elements, fkc.elements):
self_fk._schema_item_copy(other_fk)
return self._schema_item_copy(fkc)
@@ -3160,10 +3323,10 @@ class PrimaryKeyConstraint(ColumnCollectionConstraint):
"""
- __visit_name__ = 'primary_key_constraint'
+ __visit_name__ = "primary_key_constraint"
def __init__(self, *columns, **kw):
- self._implicit_generated = kw.pop('_implicit_generated', False)
+ self._implicit_generated = kw.pop("_implicit_generated", False)
super(PrimaryKeyConstraint, self).__init__(*columns, **kw)
def _set_parent(self, table):
@@ -3175,18 +3338,21 @@ class PrimaryKeyConstraint(ColumnCollectionConstraint):
table.constraints.add(self)
table_pks = [c for c in table.c if c.primary_key]
- if self.columns and table_pks and \
- set(table_pks) != set(self.columns.values()):
+ if (
+ self.columns
+ and table_pks
+ and set(table_pks) != set(self.columns.values())
+ ):
util.warn(
"Table '%s' specifies columns %s as primary_key=True, "
"not matching locally specified columns %s; setting the "
"current primary key columns to %s. This warning "
- "may become an exception in a future release" %
- (
+ "may become an exception in a future release"
+ % (
table.name,
", ".join("'%s'" % c.name for c in table_pks),
", ".join("'%s'" % c.name for c in self.columns),
- ", ".join("'%s'" % c.name for c in self.columns)
+ ", ".join("'%s'" % c.name for c in self.columns),
)
)
table_pks[:] = []
@@ -3241,28 +3407,28 @@ class PrimaryKeyConstraint(ColumnCollectionConstraint):
@util.memoized_property
def _autoincrement_column(self):
-
def _validate_autoinc(col, autoinc_true):
if col.type._type_affinity is None or not issubclass(
- col.type._type_affinity,
- type_api.INTEGERTYPE._type_affinity):
+ col.type._type_affinity, type_api.INTEGERTYPE._type_affinity
+ ):
if autoinc_true:
raise exc.ArgumentError(
"Column type %s on column '%s' is not "
- "compatible with autoincrement=True" % (
- col.type,
- col
- ))
+ "compatible with autoincrement=True" % (col.type, col)
+ )
else:
return False
- elif not isinstance(col.default, (type(None), Sequence)) and \
- not autoinc_true:
- return False
+ elif (
+ not isinstance(col.default, (type(None), Sequence))
+ and not autoinc_true
+ ):
+ return False
elif col.server_default is not None and not autoinc_true:
return False
- elif (
- col.foreign_keys and col.autoincrement
- not in (True, 'ignore_fk')):
+ elif col.foreign_keys and col.autoincrement not in (
+ True,
+ "ignore_fk",
+ ):
return False
return True
@@ -3272,10 +3438,10 @@ class PrimaryKeyConstraint(ColumnCollectionConstraint):
if col.autoincrement is True:
_validate_autoinc(col, True)
return col
- elif (
- col.autoincrement in ('auto', 'ignore_fk') and
- _validate_autoinc(col, False)
- ):
+ elif col.autoincrement in (
+ "auto",
+ "ignore_fk",
+ ) and _validate_autoinc(col, False):
return col
else:
@@ -3286,8 +3452,8 @@ class PrimaryKeyConstraint(ColumnCollectionConstraint):
if autoinc is not None:
raise exc.ArgumentError(
"Only one Column may be marked "
- "autoincrement=True, found both %s and %s." %
- (col.name, autoinc.name)
+ "autoincrement=True, found both %s and %s."
+ % (col.name, autoinc.name)
)
else:
autoinc = col
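A sketch of what _validate_autoinc permits: at most one primary-key column of integer affinity marked autoincrement=True:

from sqlalchemy import Column, Integer, MetaData, Table

Table(
    "events",
    MetaData(),
    Column("id", Integer, primary_key=True, autoincrement=True),
    Column("ts", Integer, primary_key=True, autoincrement=False),
)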
@@ -3304,7 +3470,7 @@ class UniqueConstraint(ColumnCollectionConstraint):
UniqueConstraint.
"""
- __visit_name__ = 'unique_constraint'
+ __visit_name__ = "unique_constraint"
class Index(DialectKWArgs, ColumnCollectionMixin, SchemaItem):
@@ -3382,7 +3548,7 @@ class Index(DialectKWArgs, ColumnCollectionMixin, SchemaItem):
"""
- __visit_name__ = 'index'
+ __visit_name__ = "index"
def __init__(self, name, *expressions, **kw):
r"""Construct an index object.
@@ -3420,30 +3586,35 @@ class Index(DialectKWArgs, ColumnCollectionMixin, SchemaItem):
columns = []
processed_expressions = []
- for expr, column, strname, add_element in self.\
- _extract_col_expression_collection(expressions):
+ for (
+ expr,
+ column,
+ strname,
+ add_element,
+ ) in self._extract_col_expression_collection(expressions):
if add_element is not None:
columns.append(add_element)
processed_expressions.append(expr)
self.expressions = processed_expressions
self.name = quoted_name(name, kw.pop("quote", None))
- self.unique = kw.pop('unique', False)
- _column_flag = kw.pop('_column_flag', False)
- if 'info' in kw:
- self.info = kw.pop('info')
+ self.unique = kw.pop("unique", False)
+ _column_flag = kw.pop("_column_flag", False)
+ if "info" in kw:
+ self.info = kw.pop("info")
# TODO: consider "table" argument being public, but for
# the purpose of the fix here, it starts as private.
- if '_table' in kw:
- table = kw.pop('_table')
+ if "_table" in kw:
+ table = kw.pop("_table")
self._validate_dialect_kwargs(kw)
# will call _set_parent() if table-bound column
# objects are present
ColumnCollectionMixin.__init__(
- self, *columns, _column_flag=_column_flag)
+ self, *columns, _column_flag=_column_flag
+ )
if table is not None:
self._set_parent(table)
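A sketch of the expression forms the extraction loop above accepts, both plain columns and SQL expressions (names illustrative):

from sqlalchemy import Column, Index, MetaData, String, Table, func

files = Table("files", MetaData(), Column("path", String))
Index("ix_files_path", files.c.path, unique=True)
Index("ix_files_path_lower", func.lower(files.c.path))  # expression-based index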
@@ -3454,20 +3625,17 @@ class Index(DialectKWArgs, ColumnCollectionMixin, SchemaItem):
if self.table is not None and table is not self.table:
raise exc.ArgumentError(
"Index '%s' is against table '%s', and "
- "cannot be associated with table '%s'." % (
- self.name,
- self.table.description,
- table.description
- )
+ "cannot be associated with table '%s'."
+ % (self.name, self.table.description, table.description)
)
self.table = table
table.indexes.add(self)
self.expressions = [
- expr if isinstance(expr, ClauseElement)
- else colexpr
- for expr, colexpr in util.zip_longest(self.expressions,
- self.columns)
+ expr if isinstance(expr, ClauseElement) else colexpr
+ for expr, colexpr in util.zip_longest(
+ self.expressions, self.columns
+ )
]
@property
@@ -3506,17 +3674,16 @@ class Index(DialectKWArgs, ColumnCollectionMixin, SchemaItem):
bind._run_visitor(ddl.SchemaDropper, self)
def __repr__(self):
- return 'Index(%s)' % (
+ return "Index(%s)" % (
", ".join(
- [repr(self.name)] +
- [repr(e) for e in self.expressions] +
- (self.unique and ["unique=True"] or [])
- ))
+ [repr(self.name)]
+ + [repr(e) for e in self.expressions]
+ + (self.unique and ["unique=True"] or [])
+ )
+ )
-DEFAULT_NAMING_CONVENTION = util.immutabledict({
- "ix": 'ix_%(column_0_label)s'
-})
+DEFAULT_NAMING_CONVENTION = util.immutabledict({"ix": "ix_%(column_0_label)s"})
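The default convention above names indexes only; a fuller map is commonly supplied, sketched here with the documented naming-convention tokens:

from sqlalchemy import MetaData

meta = MetaData(
    naming_convention={
        "ix": "ix_%(column_0_label)s",
        "uq": "uq_%(table_name)s_%(column_0_name)s",
        "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
        "pk": "pk_%(table_name)s",
    }
)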
class MetaData(SchemaItem):
@@ -3542,13 +3709,17 @@ class MetaData(SchemaItem):
"""
- __visit_name__ = 'metadata'
-
- def __init__(self, bind=None, reflect=False, schema=None,
- quote_schema=None,
- naming_convention=DEFAULT_NAMING_CONVENTION,
- info=None
- ):
+ __visit_name__ = "metadata"
+
+ def __init__(
+ self,
+ bind=None,
+ reflect=False,
+ schema=None,
+ quote_schema=None,
+ naming_convention=DEFAULT_NAMING_CONVENTION,
+ info=None,
+ ):
"""Create a new MetaData object.
:param bind:
@@ -3712,12 +3883,15 @@ class MetaData(SchemaItem):
self.bind = bind
if reflect:
- util.warn_deprecated("reflect=True is deprecate; please "
- "use the reflect() method.")
+ util.warn_deprecated(
+ "reflect=True is deprecate; please "
+ "use the reflect() method."
+ )
if not bind:
raise exc.ArgumentError(
"A bind must be supplied in conjunction "
- "with reflect=True")
+ "with reflect=True"
+ )
self.reflect()
tables = None
@@ -3735,7 +3909,7 @@ class MetaData(SchemaItem):
"""
def __repr__(self):
- return 'MetaData(bind=%r)' % self.bind
+ return "MetaData(bind=%r)" % self.bind
def __contains__(self, table_or_key):
if not isinstance(table_or_key, util.string_types):
@@ -3755,27 +3929,32 @@ class MetaData(SchemaItem):
for fk in removed.foreign_keys:
fk._remove_from_metadata(self)
if self._schemas:
- self._schemas = set([t.schema
- for t in self.tables.values()
- if t.schema is not None])
+ self._schemas = set(
+ [
+ t.schema
+ for t in self.tables.values()
+ if t.schema is not None
+ ]
+ )
def __getstate__(self):
- return {'tables': self.tables,
- 'schema': self.schema,
- 'schemas': self._schemas,
- 'sequences': self._sequences,
- 'fk_memos': self._fk_memos,
- 'naming_convention': self.naming_convention
- }
+ return {
+ "tables": self.tables,
+ "schema": self.schema,
+ "schemas": self._schemas,
+ "sequences": self._sequences,
+ "fk_memos": self._fk_memos,
+ "naming_convention": self.naming_convention,
+ }
def __setstate__(self, state):
- self.tables = state['tables']
- self.schema = state['schema']
- self.naming_convention = state['naming_convention']
+ self.tables = state["tables"]
+ self.schema = state["schema"]
+ self.naming_convention = state["naming_convention"]
self._bind = None
- self._sequences = state['sequences']
- self._schemas = state['schemas']
- self._fk_memos = state['fk_memos']
+ self._sequences = state["sequences"]
+ self._schemas = state["schemas"]
+ self._fk_memos = state["fk_memos"]
def is_bound(self):
"""True if this MetaData is bound to an Engine or Connection."""
@@ -3805,10 +3984,11 @@ class MetaData(SchemaItem):
def _bind_to(self, url, bind):
"""Bind this MetaData to an Engine, Connection, string or URL."""
- if isinstance(bind, util.string_types + (url.URL, )):
+ if isinstance(bind, util.string_types + (url.URL,)):
self._bind = sqlalchemy.create_engine(bind)
else:
self._bind = bind
+
bind = property(bind, _bind_to)
def clear(self):
@@ -3858,12 +4038,20 @@ class MetaData(SchemaItem):
"""
- return ddl.sort_tables(sorted(self.tables.values(), key=lambda t: t.key))
+ return ddl.sort_tables(
+ sorted(self.tables.values(), key=lambda t: t.key)
+ )
- def reflect(self, bind=None, schema=None, views=False, only=None,
- extend_existing=False,
- autoload_replace=True,
- **dialect_kwargs):
+ def reflect(
+ self,
+ bind=None,
+ schema=None,
+ views=False,
+ only=None,
+ extend_existing=False,
+ autoload_replace=True,
+ **dialect_kwargs
+ ):
r"""Load all available table definitions from the database.
Automatically creates ``Table`` entries in this ``MetaData`` for any
@@ -3926,11 +4114,11 @@ class MetaData(SchemaItem):
with bind.connect() as conn:
reflect_opts = {
- 'autoload': True,
- 'autoload_with': conn,
- 'extend_existing': extend_existing,
- 'autoload_replace': autoload_replace,
- '_extend_on': set()
+ "autoload": True,
+ "autoload_with": conn,
+ "extend_existing": extend_existing,
+ "autoload_replace": autoload_replace,
+ "_extend_on": set(),
}
reflect_opts.update(dialect_kwargs)
@@ -3939,42 +4127,49 @@ class MetaData(SchemaItem):
schema = self.schema
if schema is not None:
- reflect_opts['schema'] = schema
+ reflect_opts["schema"] = schema
available = util.OrderedSet(
- bind.engine.table_names(schema, connection=conn))
+ bind.engine.table_names(schema, connection=conn)
+ )
if views:
- available.update(
- bind.dialect.get_view_names(conn, schema)
- )
+ available.update(bind.dialect.get_view_names(conn, schema))
if schema is not None:
- available_w_schema = util.OrderedSet(["%s.%s" % (schema, name)
- for name in available])
+ available_w_schema = util.OrderedSet(
+ ["%s.%s" % (schema, name) for name in available]
+ )
else:
available_w_schema = available
current = set(self.tables)
if only is None:
- load = [name for name, schname in
- zip(available, available_w_schema)
- if extend_existing or schname not in current]
+ load = [
+ name
+ for name, schname in zip(available, available_w_schema)
+ if extend_existing or schname not in current
+ ]
elif util.callable(only):
- load = [name for name, schname in
- zip(available, available_w_schema)
- if (extend_existing or schname not in current)
- and only(name, self)]
+ load = [
+ name
+ for name, schname in zip(available, available_w_schema)
+ if (extend_existing or schname not in current)
+ and only(name, self)
+ ]
else:
missing = [name for name in only if name not in available]
if missing:
- s = schema and (" schema '%s'" % schema) or ''
+ s = schema and (" schema '%s'" % schema) or ""
raise exc.InvalidRequestError(
- 'Could not reflect: requested table(s) not available '
- 'in %r%s: (%s)' %
- (bind.engine, s, ', '.join(missing)))
- load = [name for name in only if extend_existing or
- name not in current]
+ "Could not reflect: requested table(s) not available "
+ "in %r%s: (%s)" % (bind.engine, s, ", ".join(missing))
+ )
+ load = [
+ name
+ for name in only
+ if extend_existing or name not in current
+ ]
for name in load:
try:
@@ -3989,11 +4184,12 @@ class MetaData(SchemaItem):
See :class:`.DDLEvents`.
"""
+
def adapt_listener(target, connection, **kw):
- tables = kw['tables']
+ tables = kw["tables"]
listener(event, target, connection, tables=tables)
- event.listen(self, "" + event_name.replace('-', '_'), adapt_listener)
+ event.listen(self, "" + event_name.replace("-", "_"), adapt_listener)
def create_all(self, bind=None, tables=None, checkfirst=True):
"""Create all tables stored in this metadata.
@@ -4017,10 +4213,9 @@ class MetaData(SchemaItem):
"""
if bind is None:
bind = _bind_or_error(self)
- bind._run_visitor(ddl.SchemaGenerator,
- self,
- checkfirst=checkfirst,
- tables=tables)
+ bind._run_visitor(
+ ddl.SchemaGenerator, self, checkfirst=checkfirst, tables=tables
+ )
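A sketch tying reflect() and create_all() together, assuming a reachable SQLite file URL:

from sqlalchemy import MetaData, create_engine

engine = create_engine("sqlite:///example.db")
meta = MetaData()
meta.reflect(bind=engine)     # builds Table entries for what already exists
meta.create_all(bind=engine)  # checkfirst=True, so CREATE only where needed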
def drop_all(self, bind=None, tables=None, checkfirst=True):
"""Drop all tables stored in this metadata.
@@ -4044,10 +4239,9 @@ class MetaData(SchemaItem):
"""
if bind is None:
bind = _bind_or_error(self)
- bind._run_visitor(ddl.SchemaDropper,
- self,
- checkfirst=checkfirst,
- tables=tables)
+ bind._run_visitor(
+ ddl.SchemaDropper, self, checkfirst=checkfirst, tables=tables
+ )
class ThreadLocalMetaData(MetaData):
@@ -4064,7 +4258,7 @@ class ThreadLocalMetaData(MetaData):
"""
- __visit_name__ = 'metadata'
+ __visit_name__ = "metadata"
def __init__(self):
"""Construct a ThreadLocalMetaData."""
@@ -4080,13 +4274,13 @@ class ThreadLocalMetaData(MetaData):
string or URL to automatically create a basic Engine for this bind
with ``create_engine()``."""
- return getattr(self.context, '_engine', None)
+ return getattr(self.context, "_engine", None)
@util.dependencies("sqlalchemy.engine.url")
def _bind_to(self, url, bind):
"""Bind to a Connectable in the caller's thread."""
- if isinstance(bind, util.string_types + (url.URL, )):
+ if isinstance(bind, util.string_types + (url.URL,)):
try:
self.context._engine = self.__engines[bind]
except KeyError:
@@ -4104,14 +4298,16 @@ class ThreadLocalMetaData(MetaData):
def is_bound(self):
"""True if there is a bind for this thread."""
- return (hasattr(self.context, '_engine') and
- self.context._engine is not None)
+ return (
+ hasattr(self.context, "_engine")
+ and self.context._engine is not None
+ )
def dispose(self):
"""Dispose all bound engines, in all thread contexts."""
for e in self.__engines.values():
- if hasattr(e, 'dispose'):
+ if hasattr(e, "dispose"):
e.dispose()
@@ -4128,22 +4324,25 @@ class _SchemaTranslateMap(object):
"""
- __slots__ = 'map_', '__call__', 'hash_key', 'is_default'
+
+ __slots__ = "map_", "__call__", "hash_key", "is_default"
_default_schema_getter = operator.attrgetter("schema")
def __init__(self, map_):
self.map_ = map_
if map_ is not None:
+
def schema_for_object(obj):
effective_schema = self._default_schema_getter(obj)
effective_schema = obj._translate_schema(
- effective_schema, map_)
+ effective_schema, map_
+ )
return effective_schema
+
self.__call__ = schema_for_object
self.hash_key = ";".join(
- "%s=%s" % (k, map_[k])
- for k in sorted(map_, key=str)
+ "%s=%s" % (k, map_[k]) for k in sorted(map_, key=str)
)
self.is_default = False
else:
@@ -4160,6 +4359,6 @@ class _SchemaTranslateMap(object):
else:
return _SchemaTranslateMap(map_)
+
_default_schema_map = _SchemaTranslateMap(None)
_schema_getter = _SchemaTranslateMap._schema_getter
-
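The map above backs the public schema_translate_map execution option (1.1 and later); a sketch with illustrative tenant names:

from sqlalchemy import create_engine

engine = create_engine("sqlite://")
connectable = engine.execution_options(
    schema_translate_map={None: "tenant_1", "archive": "tenant_1_archive"}
)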
diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py
index f64f152c4..1f1800514 100644
--- a/lib/sqlalchemy/sql/selectable.py
+++ b/lib/sqlalchemy/sql/selectable.py
@@ -10,15 +10,39 @@ SQL tables and derived rowsets.
"""
-from .elements import ClauseElement, TextClause, ClauseList, \
- and_, Grouping, UnaryExpression, literal_column, BindParameter
-from .elements import _clone, \
- _literal_as_text, _interpret_as_column_or_from, _expand_cloned,\
- _select_iterables, _anonymous_label, _clause_element_as_expr,\
- _cloned_intersection, _cloned_difference, True_, \
- _literal_as_label_reference, _literal_and_labels_as_label_reference
-from .base import Immutable, Executable, _generative, \
- ColumnCollection, ColumnSet, _from_objects, Generative
+from .elements import (
+ ClauseElement,
+ TextClause,
+ ClauseList,
+ and_,
+ Grouping,
+ UnaryExpression,
+ literal_column,
+ BindParameter,
+)
+from .elements import (
+ _clone,
+ _literal_as_text,
+ _interpret_as_column_or_from,
+ _expand_cloned,
+ _select_iterables,
+ _anonymous_label,
+ _clause_element_as_expr,
+ _cloned_intersection,
+ _cloned_difference,
+ True_,
+ _literal_as_label_reference,
+ _literal_and_labels_as_label_reference,
+)
+from .base import (
+ Immutable,
+ Executable,
+ _generative,
+ ColumnCollection,
+ ColumnSet,
+ _from_objects,
+ Generative,
+)
from . import type_api
from .. import inspection
from .. import util
@@ -40,7 +64,8 @@ def _interpret_as_from(element):
"Textual SQL FROM expression %(expr)r should be "
"explicitly declared as text(%(expr)r), "
"or use table(%(expr)r) for more specificity",
- {"expr": util.ellipses_string(element)})
+ {"expr": util.ellipses_string(element)},
+ )
return TextClause(util.text_type(element))
try:
@@ -73,7 +98,7 @@ def _offset_or_limit_clause(element, name=None, type_=None):
"""
if element is None:
return None
- elif hasattr(element, '__clause_element__'):
+ elif hasattr(element, "__clause_element__"):
return element.__clause_element__()
elif isinstance(element, Visitable):
return element
@@ -97,7 +122,8 @@ def _offset_or_limit_clause_asint(clause, attrname):
except AttributeError:
raise exc.CompileError(
"This SELECT structure does not use a simple "
- "integer value for %s" % attrname)
+ "integer value for %s" % attrname
+ )
else:
return util.asint(value)
@@ -225,12 +251,14 @@ def tablesample(selectable, sampling, name=None, seed=None):
"""
return _interpret_as_from(selectable).tablesample(
- sampling, name=name, seed=seed)
+ sampling, name=name, seed=seed
+ )
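A sketch of the tablesample() helper wired above; the percentage and alias are illustrative, and backend support (e.g. PostgreSQL) is assumed:

from sqlalchemy import Column, Integer, MetaData, Table, select, tablesample

users = Table("users", MetaData(), Column("id", Integer))
sampled = tablesample(users, 10, name="u")  # renders a TABLESAMPLE clause
stmt = select([sampled.c.id])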
class Selectable(ClauseElement):
"""mark a class as being selectable"""
- __visit_name__ = 'selectable'
+
+ __visit_name__ = "selectable"
is_selectable = True
@@ -265,15 +293,17 @@ class HasPrefixes(object):
limit rendering of this prefix to only that dialect.
"""
- dialect = kw.pop('dialect', None)
+ dialect = kw.pop("dialect", None)
if kw:
- raise exc.ArgumentError("Unsupported argument(s): %s" %
- ",".join(kw))
+ raise exc.ArgumentError(
+ "Unsupported argument(s): %s" % ",".join(kw)
+ )
self._setup_prefixes(expr, dialect)
def _setup_prefixes(self, prefixes, dialect=None):
self._prefixes = self._prefixes + tuple(
- [(_literal_as_text(p, warn=False), dialect) for p in prefixes])
+ [(_literal_as_text(p, warn=False), dialect) for p in prefixes]
+ )
class HasSuffixes(object):
@@ -301,15 +331,17 @@ class HasSuffixes(object):
limit rendering of this suffix to only that dialect.
"""
- dialect = kw.pop('dialect', None)
+ dialect = kw.pop("dialect", None)
if kw:
- raise exc.ArgumentError("Unsupported argument(s): %s" %
- ",".join(kw))
+ raise exc.ArgumentError(
+ "Unsupported argument(s): %s" % ",".join(kw)
+ )
self._setup_suffixes(expr, dialect)
def _setup_suffixes(self, suffixes, dialect=None):
self._suffixes = self._suffixes + tuple(
- [(_literal_as_text(p, warn=False), dialect) for p in suffixes])
+ [(_literal_as_text(p, warn=False), dialect) for p in suffixes]
+ )
class FromClause(Selectable):
@@ -330,7 +362,8 @@ class FromClause(Selectable):
"""
- __visit_name__ = 'fromclause'
+
+ __visit_name__ = "fromclause"
named_with_column = False
_hide_froms = []
@@ -359,13 +392,14 @@ class FromClause(Selectable):
_memoized_property = util.group_expirable_memoized_property(["_columns"])
@util.deprecated(
- '1.1',
+ "1.1",
message="``FromClause.count()`` is deprecated. Counting "
"rows requires that the correct column expression and "
"accommodations for joins, DISTINCT, etc. must be made, "
"otherwise results may not be what's expected. "
"Please use an appropriate ``func.count()`` expression "
- "directly.")
+ "directly.",
+ )
@util.dependencies("sqlalchemy.sql.functions")
def count(self, functions, whereclause=None, **params):
"""return a SELECT COUNT generated against this
@@ -392,10 +426,11 @@ class FromClause(Selectable):
else:
col = list(self.columns)[0]
return Select(
- [functions.func.count(col).label('tbl_row_count')],
+ [functions.func.count(col).label("tbl_row_count")],
whereclause,
from_obj=[self],
- **params)
+ **params
+ )
def select(self, whereclause=None, **params):
"""return a SELECT of this :class:`.FromClause`.
@@ -603,8 +638,9 @@ class FromClause(Selectable):
def embedded(expanded_proxy_set, target_set):
for t in target_set.difference(expanded_proxy_set):
- if not set(_expand_cloned([t])
- ).intersection(expanded_proxy_set):
+ if not set(_expand_cloned([t])).intersection(
+ expanded_proxy_set
+ ):
return False
return True
@@ -617,8 +653,10 @@ class FromClause(Selectable):
for c in cols:
expanded_proxy_set = set(_expand_cloned(c.proxy_set))
i = target_set.intersection(expanded_proxy_set)
- if i and (not require_embedded
- or embedded(expanded_proxy_set, target_set)):
+ if i and (
+ not require_embedded
+ or embedded(expanded_proxy_set, target_set)
+ ):
if col is None:
# no corresponding column yet, pick this one.
@@ -646,12 +684,20 @@ class FromClause(Selectable):
col_distance = util.reduce(
operator.add,
- [sc._annotations.get('weight', 1) for sc in
- col.proxy_set if sc.shares_lineage(column)])
+ [
+ sc._annotations.get("weight", 1)
+ for sc in col.proxy_set
+ if sc.shares_lineage(column)
+ ],
+ )
c_distance = util.reduce(
operator.add,
- [sc._annotations.get('weight', 1) for sc in
- c.proxy_set if sc.shares_lineage(column)])
+ [
+ sc._annotations.get("weight", 1)
+ for sc in c.proxy_set
+ if sc.shares_lineage(column)
+ ],
+ )
if c_distance < col_distance:
col, intersect = c, i
return col
@@ -663,7 +709,7 @@ class FromClause(Selectable):
Used primarily for error message formatting.
"""
- return getattr(self, 'name', self.__class__.__name__ + " object")
+ return getattr(self, "name", self.__class__.__name__ + " object")
def _reset_exported(self):
"""delete memoized collections when a FromClause is cloned."""
@@ -683,7 +729,7 @@ class FromClause(Selectable):
"""
- if '_columns' not in self.__dict__:
+ if "_columns" not in self.__dict__:
self._init_collections()
self._populate_column_collection()
return self._columns.as_immutable()
@@ -706,14 +752,16 @@ class FromClause(Selectable):
self._populate_column_collection()
return self.foreign_keys
- c = property(attrgetter('columns'),
- doc="An alias for the :attr:`.columns` attribute.")
- _select_iterable = property(attrgetter('columns'))
+ c = property(
+ attrgetter("columns"),
+ doc="An alias for the :attr:`.columns` attribute.",
+ )
+ _select_iterable = property(attrgetter("columns"))
def _init_collections(self):
- assert '_columns' not in self.__dict__
- assert 'primary_key' not in self.__dict__
- assert 'foreign_keys' not in self.__dict__
+ assert "_columns" not in self.__dict__
+ assert "primary_key" not in self.__dict__
+ assert "foreign_keys" not in self.__dict__
self._columns = ColumnCollection()
self.primary_key = ColumnSet()
@@ -721,7 +769,7 @@ class FromClause(Selectable):
@property
def _cols_populated(self):
- return '_columns' in self.__dict__
+ return "_columns" in self.__dict__
def _populate_column_collection(self):
"""Called on subclasses to establish the .c collection.
@@ -758,8 +806,7 @@ class FromClause(Selectable):
"""
if not self._cols_populated:
return None
- elif (column.key in self.columns and
- self.columns[column.key] is column):
+ elif column.key in self.columns and self.columns[column.key] is column:
return column
else:
return None
@@ -780,7 +827,8 @@ class Join(FromClause):
:meth:`.FromClause.join`
"""
- __visit_name__ = 'join'
+
+ __visit_name__ = "join"
_is_join = True
@@ -829,8 +877,9 @@ class Join(FromClause):
return cls(left, right, onclause, isouter=True, full=full)
@classmethod
- def _create_join(cls, left, right, onclause=None, isouter=False,
- full=False):
+ def _create_join(
+ cls, left, right, onclause=None, isouter=False, full=False
+ ):
"""Produce a :class:`.Join` object, given two :class:`.FromClause`
expressions.
@@ -882,26 +931,34 @@ class Join(FromClause):
self.left.description,
id(self.left),
self.right.description,
- id(self.right))
+ id(self.right),
+ )
def is_derived_from(self, fromclause):
- return fromclause is self or \
- self.left.is_derived_from(fromclause) or \
- self.right.is_derived_from(fromclause)
+ return (
+ fromclause is self
+ or self.left.is_derived_from(fromclause)
+ or self.right.is_derived_from(fromclause)
+ )
def self_group(self, against=None):
return FromGrouping(self)
@util.dependencies("sqlalchemy.sql.util")
def _populate_column_collection(self, sqlutil):
- columns = [c for c in self.left.columns] + \
- [c for c in self.right.columns]
+ columns = [c for c in self.left.columns] + [
+ c for c in self.right.columns
+ ]
- self.primary_key.extend(sqlutil.reduce_columns(
- (c for c in columns if c.primary_key), self.onclause))
+ self.primary_key.extend(
+ sqlutil.reduce_columns(
+ (c for c in columns if c.primary_key), self.onclause
+ )
+ )
self._columns.update((col._label, col) for col in columns)
- self.foreign_keys.update(itertools.chain(
- *[col.foreign_keys for col in columns]))
+ self.foreign_keys.update(
+ itertools.chain(*[col.foreign_keys for col in columns])
+ )
def _refresh_for_new_column(self, column):
col = self.left._refresh_for_new_column(column)
@@ -933,9 +990,14 @@ class Join(FromClause):
return self._join_condition(left, right, a_subset=left_right)
@classmethod
- def _join_condition(cls, a, b, ignore_nonexistent_tables=False,
- a_subset=None,
- consider_as_foreign_keys=None):
+ def _join_condition(
+ cls,
+ a,
+ b,
+ ignore_nonexistent_tables=False,
+ a_subset=None,
+ consider_as_foreign_keys=None,
+ ):
"""create a join condition between two tables or selectables.
e.g.::
@@ -963,26 +1025,31 @@ class Join(FromClause):
"""
constraints = cls._joincond_scan_left_right(
- a, a_subset, b, consider_as_foreign_keys)
+ a, a_subset, b, consider_as_foreign_keys
+ )
if len(constraints) > 1:
cls._joincond_trim_constraints(
- a, b, constraints, consider_as_foreign_keys)
+ a, b, constraints, consider_as_foreign_keys
+ )
if len(constraints) == 0:
if isinstance(b, FromGrouping):
- hint = " Perhaps you meant to convert the right side to a "\
+ hint = (
+ " Perhaps you meant to convert the right side to a "
"subquery using alias()?"
+ )
else:
hint = ""
raise exc.NoForeignKeysError(
"Can't find any foreign key relationships "
- "between '%s' and '%s'.%s" %
- (a.description, b.description, hint))
+ "between '%s' and '%s'.%s"
+ % (a.description, b.description, hint)
+ )
crit = [(x == y) for x, y in list(constraints.values())[0]]
if len(crit) == 1:
- return (crit[0])
+ return crit[0]
else:
return and_(*crit)
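
For context, the foreign-key scan above is what lets ``join()`` omit an
explicit ON clause. A minimal sketch of that behavior against the public 1.x
API; ``users`` and ``addresses`` are hypothetical tables linked by a single
foreign key::

    from sqlalchemy import MetaData, Table, Column, Integer, ForeignKey

    meta = MetaData()
    users = Table("users", meta, Column("id", Integer, primary_key=True))
    addresses = Table(
        "addresses",
        meta,
        Column("id", Integer, primary_key=True),
        Column("user_id", Integer, ForeignKey("users.id")),
    )

    # onclause omitted: the single FK constraint is found and used,
    # yielding users.id == addresses.user_id
    j = users.join(addresses)
    print(j.onclause)
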
@@ -994,24 +1061,30 @@ class Join(FromClause):
left_right = None
constraints = cls._joincond_scan_left_right(
- a=left, b=right, a_subset=left_right,
- consider_as_foreign_keys=consider_as_foreign_keys)
+ a=left,
+ b=right,
+ a_subset=left_right,
+ consider_as_foreign_keys=consider_as_foreign_keys,
+ )
return bool(constraints)
@classmethod
def _joincond_scan_left_right(
- cls, a, a_subset, b, consider_as_foreign_keys):
+ cls, a, a_subset, b, consider_as_foreign_keys
+ ):
constraints = collections.defaultdict(list)
for left in (a_subset, a):
if left is None:
continue
for fk in sorted(
- b.foreign_keys,
- key=lambda fk: fk.parent._creation_order):
- if consider_as_foreign_keys is not None and \
- fk.parent not in consider_as_foreign_keys:
+ b.foreign_keys, key=lambda fk: fk.parent._creation_order
+ ):
+ if (
+ consider_as_foreign_keys is not None
+ and fk.parent not in consider_as_foreign_keys
+ ):
continue
try:
col = fk.get_referent(left)
@@ -1025,10 +1098,12 @@ class Join(FromClause):
constraints[fk.constraint].append((col, fk.parent))
if left is not b:
for fk in sorted(
- left.foreign_keys,
- key=lambda fk: fk.parent._creation_order):
- if consider_as_foreign_keys is not None and \
- fk.parent not in consider_as_foreign_keys:
+ left.foreign_keys, key=lambda fk: fk.parent._creation_order
+ ):
+ if (
+ consider_as_foreign_keys is not None
+ and fk.parent not in consider_as_foreign_keys
+ ):
continue
try:
col = fk.get_referent(b)
@@ -1046,14 +1121,16 @@ class Join(FromClause):
@classmethod
def _joincond_trim_constraints(
- cls, a, b, constraints, consider_as_foreign_keys):
+ cls, a, b, constraints, consider_as_foreign_keys
+ ):
# more than one constraint matched. narrow down the list
# to include just those FKCs that match exactly to
# "consider_as_foreign_keys".
if consider_as_foreign_keys:
for const in list(constraints):
if set(f.parent for f in const.elements) != set(
- consider_as_foreign_keys):
+ consider_as_foreign_keys
+ ):
del constraints[const]
# if still multiple constraints, but
@@ -1070,8 +1147,8 @@ class Join(FromClause):
"tables have more than one foreign key "
"constraint relationship between them. "
"Please specify the 'onclause' of this "
- "join explicitly." % (a.description, b.description))
-
+ "join explicitly." % (a.description, b.description)
+ )
def select(self, whereclause=None, **kwargs):
r"""Create a :class:`.Select` from this :class:`.Join`.
@@ -1200,27 +1277,37 @@ class Join(FromClause):
"""
if flat:
assert name is None, "Can't send name argument with flat"
- left_a, right_a = self.left.alias(flat=True), \
- self.right.alias(flat=True)
- adapter = sqlutil.ClauseAdapter(left_a).\
- chain(sqlutil.ClauseAdapter(right_a))
+ left_a, right_a = (
+ self.left.alias(flat=True),
+ self.right.alias(flat=True),
+ )
+ adapter = sqlutil.ClauseAdapter(left_a).chain(
+ sqlutil.ClauseAdapter(right_a)
+ )
- return left_a.join(right_a, adapter.traverse(self.onclause),
- isouter=self.isouter, full=self.full)
+ return left_a.join(
+ right_a,
+ adapter.traverse(self.onclause),
+ isouter=self.isouter,
+ full=self.full,
+ )
else:
return self.select(use_labels=True, correlate=False).alias(name)
@property
def _hide_froms(self):
- return itertools.chain(*[_from_objects(x.left, x.right)
- for x in self._cloned_set])
+ return itertools.chain(
+ *[_from_objects(x.left, x.right) for x in self._cloned_set]
+ )
@property
def _from_objects(self):
- return [self] + \
- self.onclause._from_objects + \
- self.left._from_objects + \
- self.right._from_objects
+ return (
+ [self]
+ + self.onclause._from_objects
+ + self.left._from_objects
+ + self.right._from_objects
+ )
class Alias(FromClause):
@@ -1236,7 +1323,7 @@ class Alias(FromClause):
"""
- __visit_name__ = 'alias'
+ __visit_name__ = "alias"
named_with_column = True
_is_from_container = True
@@ -1252,15 +1339,16 @@ class Alias(FromClause):
self.element = selectable
if name is None:
if self.original.named_with_column:
- name = getattr(self.original, 'name', None)
- name = _anonymous_label('%%(%d %s)s' % (id(self), name
- or 'anon'))
+ name = getattr(self.original, "name", None)
+ name = _anonymous_label("%%(%d %s)s" % (id(self), name or "anon"))
self.name = name
def self_group(self, against=None):
- if isinstance(against, CompoundSelect) and \
- isinstance(self.original, Select) and \
- self.original._needs_parens_for_grouping():
+ if (
+ isinstance(against, CompoundSelect)
+ and isinstance(self.original, Select)
+ and self.original._needs_parens_for_grouping()
+ ):
return FromGrouping(self)
return super(Alias, self).self_group(against=against)
@@ -1270,14 +1358,15 @@ class Alias(FromClause):
if util.py3k:
return self.name
else:
- return self.name.encode('ascii', 'backslashreplace')
+ return self.name.encode("ascii", "backslashreplace")
def as_scalar(self):
try:
return self.element.as_scalar()
except AttributeError:
- raise AttributeError("Element %s does not support "
- "'as_scalar()'" % self.element)
+ raise AttributeError(
+ "Element %s does not support " "'as_scalar()'" % self.element
+ )
def is_derived_from(self, fromclause):
if fromclause in self._cloned_set:
@@ -1344,7 +1433,7 @@ class Lateral(Alias):
"""
- __visit_name__ = 'lateral'
+ __visit_name__ = "lateral"
_is_lateral = True
@@ -1363,11 +1452,9 @@ class TableSample(Alias):
"""
- __visit_name__ = 'tablesample'
+ __visit_name__ = "tablesample"
- def __init__(self, selectable, sampling,
- name=None,
- seed=None):
+ def __init__(self, selectable, sampling, name=None, seed=None):
self.sampling = sampling
self.seed = seed
super(TableSample, self).__init__(selectable, name=name)
@@ -1390,14 +1477,18 @@ class CTE(Generative, HasSuffixes, Alias):
.. versionadded:: 0.7.6
"""
- __visit_name__ = 'cte'
-
- def __init__(self, selectable,
- name=None,
- recursive=False,
- _cte_alias=None,
- _restates=frozenset(),
- _suffixes=None):
+
+ __visit_name__ = "cte"
+
+ def __init__(
+ self,
+ selectable,
+ name=None,
+ recursive=False,
+ _cte_alias=None,
+ _restates=frozenset(),
+ _suffixes=None,
+ ):
self.recursive = recursive
self._cte_alias = _cte_alias
self._restates = _restates
@@ -1409,9 +1500,9 @@ class CTE(Generative, HasSuffixes, Alias):
super(CTE, self)._copy_internals(clone, **kw)
if self._cte_alias is not None:
self._cte_alias = clone(self._cte_alias, **kw)
- self._restates = frozenset([
- clone(elem, **kw) for elem in self._restates
- ])
+ self._restates = frozenset(
+ [clone(elem, **kw) for elem in self._restates]
+ )
@util.dependencies("sqlalchemy.sql.dml")
def _populate_column_collection(self, dml):
@@ -1428,7 +1519,7 @@ class CTE(Generative, HasSuffixes, Alias):
name=name,
recursive=self.recursive,
_cte_alias=self,
- _suffixes=self._suffixes
+ _suffixes=self._suffixes,
)
def union(self, other):
@@ -1437,7 +1528,7 @@ class CTE(Generative, HasSuffixes, Alias):
name=self.name,
recursive=self.recursive,
_restates=self._restates.union([self]),
- _suffixes=self._suffixes
+ _suffixes=self._suffixes,
)
def union_all(self, other):
@@ -1446,7 +1537,7 @@ class CTE(Generative, HasSuffixes, Alias):
name=self.name,
recursive=self.recursive,
_restates=self._restates.union([self]),
- _suffixes=self._suffixes
+ _suffixes=self._suffixes,
)
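
The ``union()`` / ``union_all()`` methods above are the building blocks of
recursive CTEs; a condensed version of the documented pattern, with ``parts``
as a hypothetical table::

    from sqlalchemy import table, column, select

    parts = table("parts", column("part"), column("sub_part"))

    included = (
        select([parts.c.sub_part])
        .where(parts.c.part == "our part")
        .cte(name="included_parts", recursive=True)
    )

    # the self-referential UNION ALL; _restates records the restated CTE
    included = included.union_all(
        select([parts.c.sub_part]).where(parts.c.part == included.c.sub_part)
    )
    stmt = select([included.c.sub_part])
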
@@ -1620,7 +1711,8 @@ class HasCTE(object):
class FromGrouping(FromClause):
"""Represent a grouping of a FROM clause"""
- __visit_name__ = 'grouping'
+
+ __visit_name__ = "grouping"
def __init__(self, element):
self.element = element
@@ -1651,7 +1743,7 @@ class FromGrouping(FromClause):
return self.element._hide_froms
def get_children(self, **kwargs):
- return self.element,
+ return (self.element,)
def _copy_internals(self, clone=_clone, **kw):
self.element = clone(self.element, **kw)
@@ -1664,10 +1756,10 @@ class FromGrouping(FromClause):
return getattr(self.element, attr)
def __getstate__(self):
- return {'element': self.element}
+ return {"element": self.element}
def __setstate__(self, state):
- self.element = state['element']
+ self.element = state["element"]
class TableClause(Immutable, FromClause):
@@ -1699,7 +1791,7 @@ class TableClause(Immutable, FromClause):
"""
- __visit_name__ = 'table'
+ __visit_name__ = "table"
named_with_column = True
@@ -1744,7 +1836,7 @@ class TableClause(Immutable, FromClause):
if util.py3k:
return self.name
else:
- return self.name.encode('ascii', 'backslashreplace')
+ return self.name.encode("ascii", "backslashreplace")
def append_column(self, c):
self._columns[c.key] = c
@@ -1773,7 +1865,8 @@ class TableClause(Immutable, FromClause):
@util.dependencies("sqlalchemy.sql.dml")
def update(
- self, dml, whereclause=None, values=None, inline=False, **kwargs):
+ self, dml, whereclause=None, values=None, inline=False, **kwargs
+ ):
"""Generate an :func:`.update` construct against this
:class:`.TableClause`.
@@ -1785,8 +1878,13 @@ class TableClause(Immutable, FromClause):
"""
- return dml.Update(self, whereclause=whereclause,
- values=values, inline=inline, **kwargs)
+ return dml.Update(
+ self,
+ whereclause=whereclause,
+ values=values,
+ inline=inline,
+ **kwargs
+ )
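
A quick sketch of the lightweight ``table()`` / ``update()`` path shown here;
``user`` is a hypothetical :func:`.table` construct::

    from sqlalchemy import table, column

    user = table("user", column("id"), column("name"))

    # dml.Update supports both the constructor arguments above and the
    # generative .where() / .values() forms
    stmt = user.update().where(user.c.id == 5).values(name="ed")
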
@util.dependencies("sqlalchemy.sql.dml")
def delete(self, dml, whereclause=None, **kwargs):
@@ -1809,7 +1907,6 @@ class TableClause(Immutable, FromClause):
class ForUpdateArg(ClauseElement):
-
@classmethod
def parse_legacy_select(self, arg):
"""Parse the for_update argument of :func:`.select`.
@@ -1836,11 +1933,11 @@ class ForUpdateArg(ClauseElement):
return None
nowait = read = False
- if arg == 'nowait':
+ if arg == "nowait":
nowait = True
- elif arg == 'read':
+ elif arg == "read":
read = True
- elif arg == 'read_nowait':
+ elif arg == "read_nowait":
read = nowait = True
elif arg is not True:
raise exc.ArgumentError("Unknown for_update argument: %r" % arg)
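
The legacy string forms handled above map onto the flag-based
``with_for_update()``; a rough sketch of the correspondence (``t``
hypothetical, the string spelling being the legacy one)::

    from sqlalchemy import select, table, column

    t = table("t", column("id"))

    # legacy: string consumed by ForUpdateArg.parse_legacy_select()
    legacy = select([t.c.id], for_update="read_nowait")

    # modern equivalent per the mapping above
    modern = select([t.c.id]).with_for_update(read=True, nowait=True)
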
@@ -1860,12 +1957,12 @@ class ForUpdateArg(ClauseElement):
def __eq__(self, other):
return (
- isinstance(other, ForUpdateArg) and
- other.nowait == self.nowait and
- other.read == self.read and
- other.skip_locked == self.skip_locked and
- other.key_share == self.key_share and
- other.of is self.of
+ isinstance(other, ForUpdateArg)
+ and other.nowait == self.nowait
+ and other.read == self.read
+ and other.skip_locked == self.skip_locked
+ and other.key_share == self.key_share
+ and other.of is self.of
)
def __hash__(self):
@@ -1876,8 +1973,13 @@ class ForUpdateArg(ClauseElement):
self.of = [clone(col, **kw) for col in self.of]
def __init__(
- self, nowait=False, read=False, of=None,
- skip_locked=False, key_share=False):
+ self,
+ nowait=False,
+ read=False,
+ of=None,
+ skip_locked=False,
+ key_share=False,
+ ):
"""Represents arguments specified to :meth:`.Select.for_update`.
.. versionadded:: 0.9.0
@@ -1889,8 +1991,9 @@ class ForUpdateArg(ClauseElement):
self.skip_locked = skip_locked
self.key_share = key_share
if of is not None:
- self.of = [_interpret_as_column_or_from(elem)
- for elem in util.to_list(of)]
+ self.of = [
+ _interpret_as_column_or_from(elem) for elem in util.to_list(of)
+ ]
else:
self.of = None
@@ -1930,17 +2033,20 @@ class SelectBase(HasCTE, Executable, FromClause):
return self.as_scalar().label(name)
@_generative
- @util.deprecated('0.6',
- message="``autocommit()`` is deprecated. Use "
- ":meth:`.Executable.execution_options` with the "
- "'autocommit' flag.")
+ @util.deprecated(
+ "0.6",
+ message="``autocommit()`` is deprecated. Use "
+ ":meth:`.Executable.execution_options` with the "
+ "'autocommit' flag.",
+ )
def autocommit(self):
"""return a new selectable with the 'autocommit' flag set to
True.
"""
- self._execution_options = \
- self._execution_options.union({'autocommit': True})
+ self._execution_options = self._execution_options.union(
+ {"autocommit": True}
+ )
def _generate(self):
"""Override the default _generate() method to also clear out
@@ -1973,34 +2079,38 @@ class GenerativeSelect(SelectBase):
used for other SELECT-like objects, e.g. :class:`.TextAsFrom`.
"""
+
_order_by_clause = ClauseList()
_group_by_clause = ClauseList()
_limit_clause = None
_offset_clause = None
_for_update_arg = None
- def __init__(self,
- use_labels=False,
- for_update=False,
- limit=None,
- offset=None,
- order_by=None,
- group_by=None,
- bind=None,
- autocommit=None):
+ def __init__(
+ self,
+ use_labels=False,
+ for_update=False,
+ limit=None,
+ offset=None,
+ order_by=None,
+ group_by=None,
+ bind=None,
+ autocommit=None,
+ ):
self.use_labels = use_labels
if for_update is not False:
- self._for_update_arg = (ForUpdateArg.
- parse_legacy_select(for_update))
+ self._for_update_arg = ForUpdateArg.parse_legacy_select(for_update)
if autocommit is not None:
- util.warn_deprecated('autocommit on select() is '
- 'deprecated. Use .execution_options(a'
- 'utocommit=True)')
- self._execution_options = \
- self._execution_options.union(
- {'autocommit': autocommit})
+ util.warn_deprecated(
+ "autocommit on select() is "
+ "deprecated. Use .execution_options(a"
+ "utocommit=True)"
+ )
+ self._execution_options = self._execution_options.union(
+ {"autocommit": autocommit}
+ )
if limit is not None:
self._limit_clause = _offset_or_limit_clause(limit)
if offset is not None:
@@ -2010,11 +2120,13 @@ class GenerativeSelect(SelectBase):
if order_by is not None:
self._order_by_clause = ClauseList(
*util.to_list(order_by),
- _literal_as_text=_literal_and_labels_as_label_reference)
+ _literal_as_text=_literal_and_labels_as_label_reference
+ )
if group_by is not None:
self._group_by_clause = ClauseList(
*util.to_list(group_by),
- _literal_as_text=_literal_as_label_reference)
+ _literal_as_text=_literal_as_label_reference
+ )
@property
def for_update(self):
@@ -2030,8 +2142,14 @@ class GenerativeSelect(SelectBase):
self._for_update_arg = ForUpdateArg.parse_legacy_select(value)
@_generative
- def with_for_update(self, nowait=False, read=False, of=None,
- skip_locked=False, key_share=False):
+ def with_for_update(
+ self,
+ nowait=False,
+ read=False,
+ of=None,
+ skip_locked=False,
+ key_share=False,
+ ):
"""Specify a ``FOR UPDATE`` clause for this :class:`.GenerativeSelect`.
E.g.::
@@ -2079,9 +2197,13 @@ class GenerativeSelect(SelectBase):
.. versionadded:: 1.1.0
"""
- self._for_update_arg = ForUpdateArg(nowait=nowait, read=read, of=of,
- skip_locked=skip_locked,
- key_share=key_share)
+ self._for_update_arg = ForUpdateArg(
+ nowait=nowait,
+ read=read,
+ of=of,
+ skip_locked=skip_locked,
+ key_share=key_share,
+ )
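
Putting the keyword plumbing above together, a minimal sketch (``t``
hypothetical)::

    from sqlalchemy import select, table, column

    t = table("t", column("id"))

    # renders e.g. SELECT ... FOR UPDATE OF t SKIP LOCKED on
    # dialects that support those options
    stmt = select([t.c.id]).with_for_update(of=t, skip_locked=True)
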
@_generative
def apply_labels(self):
@@ -2209,11 +2331,12 @@ class GenerativeSelect(SelectBase):
if len(clauses) == 1 and clauses[0] is None:
self._order_by_clause = ClauseList()
else:
- if getattr(self, '_order_by_clause', None) is not None:
+ if getattr(self, "_order_by_clause", None) is not None:
clauses = list(self._order_by_clause) + list(clauses)
self._order_by_clause = ClauseList(
*clauses,
- _literal_as_text=_literal_and_labels_as_label_reference)
+ _literal_as_text=_literal_and_labels_as_label_reference
+ )
def append_group_by(self, *clauses):
"""Append the given GROUP BY criterion applied to this selectable.
@@ -2228,10 +2351,11 @@ class GenerativeSelect(SelectBase):
if len(clauses) == 1 and clauses[0] is None:
self._group_by_clause = ClauseList()
else:
- if getattr(self, '_group_by_clause', None) is not None:
+ if getattr(self, "_group_by_clause", None) is not None:
clauses = list(self._group_by_clause) + list(clauses)
self._group_by_clause = ClauseList(
- *clauses, _literal_as_text=_literal_as_label_reference)
+ *clauses, _literal_as_text=_literal_as_label_reference
+ )
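
The ``clauses[0] is None`` checks above implement the documented reset
behavior; a sketch (``t`` hypothetical)::

    from sqlalchemy import select, table, column

    t = table("t", column("id"))

    stmt = select([t.c.id]).order_by(t.c.id)

    # passing None replaces the ORDER BY with an empty ClauseList,
    # per the None check above; group_by(None) behaves the same way
    stmt = stmt.order_by(None)
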
@property
def _label_resolve_dict(self):
@@ -2265,19 +2389,19 @@ class CompoundSelect(GenerativeSelect):
"""
- __visit_name__ = 'compound_select'
+ __visit_name__ = "compound_select"
- UNION = util.symbol('UNION')
- UNION_ALL = util.symbol('UNION ALL')
- EXCEPT = util.symbol('EXCEPT')
- EXCEPT_ALL = util.symbol('EXCEPT ALL')
- INTERSECT = util.symbol('INTERSECT')
- INTERSECT_ALL = util.symbol('INTERSECT ALL')
+ UNION = util.symbol("UNION")
+ UNION_ALL = util.symbol("UNION ALL")
+ EXCEPT = util.symbol("EXCEPT")
+ EXCEPT_ALL = util.symbol("EXCEPT ALL")
+ INTERSECT = util.symbol("INTERSECT")
+ INTERSECT_ALL = util.symbol("INTERSECT ALL")
_is_from_container = True
def __init__(self, keyword, *selects, **kwargs):
- self._auto_correlate = kwargs.pop('correlate', False)
+ self._auto_correlate = kwargs.pop("correlate", False)
self.keyword = keyword
self.selects = []
@@ -2291,12 +2415,16 @@ class CompoundSelect(GenerativeSelect):
numcols = len(s.c._all_columns)
elif len(s.c._all_columns) != numcols:
raise exc.ArgumentError(
- 'All selectables passed to '
- 'CompoundSelect must have identical numbers of '
- 'columns; select #%d has %d columns, select '
- '#%d has %d' %
- (1, len(self.selects[0].c._all_columns),
- n + 1, len(s.c._all_columns))
+ "All selectables passed to "
+ "CompoundSelect must have identical numbers of "
+ "columns; select #%d has %d columns, select "
+ "#%d has %d"
+ % (
+ 1,
+ len(self.selects[0].c._all_columns),
+ n + 1,
+ len(s.c._all_columns),
+ )
)
self.selects.append(s.self_group(against=self))
@@ -2305,9 +2433,7 @@ class CompoundSelect(GenerativeSelect):
@property
def _label_resolve_dict(self):
- d = dict(
- (c.key, c) for c in self.c
- )
+ d = dict((c.key, c) for c in self.c)
return d, d, d
@classmethod
@@ -2416,8 +2542,7 @@ class CompoundSelect(GenerativeSelect):
:func:`select`.
"""
- return CompoundSelect(
- CompoundSelect.INTERSECT_ALL, *selects, **kwargs)
+ return CompoundSelect(CompoundSelect.INTERSECT_ALL, *selects, **kwargs)
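
These classmethods back the module-level compound-select functions; a brief
sketch (``t`` hypothetical)::

    from sqlalchemy import select, table, column, union, intersect

    t = table("t", column("x"))
    s1 = select([t.c.x]).where(t.c.x > 1)
    s2 = select([t.c.x]).where(t.c.x < 10)

    u = union(s1, s2)      # CompoundSelect with keyword UNION
    i = intersect(s1, s2)  # CompoundSelect with keyword INTERSECT
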
def _scalar_type(self):
return self.selects[0]._scalar_type()
@@ -2445,8 +2570,10 @@ class CompoundSelect(GenerativeSelect):
# those fks too.
proxy = cols[0]._make_proxy(
- self, name=cols[0]._label if self.use_labels else None,
- key=cols[0]._key_label if self.use_labels else None)
+ self,
+ name=cols[0]._label if self.use_labels else None,
+ key=cols[0]._key_label if self.use_labels else None,
+ )
# hand-construct the "_proxies" collection to include all
# derived columns place a 'weight' annotation corresponding
@@ -2455,7 +2582,8 @@ class CompoundSelect(GenerativeSelect):
# conflicts
proxy._proxies = [
- c._annotate({'weight': i + 1}) for (i, c) in enumerate(cols)]
+ c._annotate({"weight": i + 1}) for (i, c) in enumerate(cols)
+ ]
def _refresh_for_new_column(self, column):
for s in self.selects:
@@ -2464,25 +2592,32 @@ class CompoundSelect(GenerativeSelect):
if not self._cols_populated:
return None
- raise NotImplementedError("CompoundSelect constructs don't support "
- "addition of columns to underlying "
- "selectables")
+ raise NotImplementedError(
+ "CompoundSelect constructs don't support "
+ "addition of columns to underlying "
+ "selectables"
+ )
def _copy_internals(self, clone=_clone, **kw):
super(CompoundSelect, self)._copy_internals(clone, **kw)
self._reset_exported()
self.selects = [clone(s, **kw) for s in self.selects]
- if hasattr(self, '_col_map'):
+ if hasattr(self, "_col_map"):
del self._col_map
for attr in (
- '_order_by_clause', '_group_by_clause', '_for_update_arg'):
+ "_order_by_clause",
+ "_group_by_clause",
+ "_for_update_arg",
+ ):
if getattr(self, attr) is not None:
setattr(self, attr, clone(getattr(self, attr), **kw))
def get_children(self, column_collections=True, **kwargs):
- return (column_collections and list(self.c) or []) \
- + [self._order_by_clause, self._group_by_clause] \
+ return (
+ (column_collections and list(self.c) or [])
+ + [self._order_by_clause, self._group_by_clause]
+ list(self.selects)
+ )
def bind(self):
if self._bind:
@@ -2496,6 +2631,7 @@ class CompoundSelect(GenerativeSelect):
def _set_bind(self, bind):
self._bind = bind
+
bind = property(bind, _set_bind)
@@ -2504,7 +2640,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
"""
- __visit_name__ = 'select'
+ __visit_name__ = "select"
_prefixes = ()
_suffixes = ()
@@ -2517,16 +2653,18 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
_memoized_property = SelectBase._memoized_property
_is_select = True
- def __init__(self,
- columns=None,
- whereclause=None,
- from_obj=None,
- distinct=False,
- having=None,
- correlate=True,
- prefixes=None,
- suffixes=None,
- **kwargs):
+ def __init__(
+ self,
+ columns=None,
+ whereclause=None,
+ from_obj=None,
+ distinct=False,
+ having=None,
+ correlate=True,
+ prefixes=None,
+ suffixes=None,
+ **kwargs
+ ):
"""Construct a new :class:`.Select`.
Similar functionality is also available via the
@@ -2729,22 +2867,23 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
self._distinct = True
else:
self._distinct = [
- _literal_as_text(e)
- for e in util.to_list(distinct)
+ _literal_as_text(e) for e in util.to_list(distinct)
]
if from_obj is not None:
self._from_obj = util.OrderedSet(
- _interpret_as_from(f)
- for f in util.to_list(from_obj))
+ _interpret_as_from(f) for f in util.to_list(from_obj)
+ )
else:
self._from_obj = util.OrderedSet()
try:
cols_present = bool(columns)
except TypeError:
- raise exc.ArgumentError("columns argument to select() must "
- "be a Python list or other iterable")
+ raise exc.ArgumentError(
+ "columns argument to select() must "
+ "be a Python list or other iterable"
+ )
if cols_present:
self._raw_columns = []
@@ -2757,14 +2896,16 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
self._raw_columns = []
if whereclause is not None:
- self._whereclause = _literal_as_text(
- whereclause).self_group(against=operators._asbool)
+ self._whereclause = _literal_as_text(whereclause).self_group(
+ against=operators._asbool
+ )
else:
self._whereclause = None
if having is not None:
- self._having = _literal_as_text(
- having).self_group(against=operators._asbool)
+ self._having = _literal_as_text(having).self_group(
+ against=operators._asbool
+ )
else:
self._having = None
@@ -2789,12 +2930,14 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
for item in itertools.chain(
_from_objects(*self._raw_columns),
_from_objects(self._whereclause)
- if self._whereclause is not None else (),
- self._from_obj
+ if self._whereclause is not None
+ else (),
+ self._from_obj,
):
if item is self:
raise exc.InvalidRequestError(
- "select() construct refers to itself as a FROM")
+ "select() construct refers to itself as a FROM"
+ )
if translate and item in translate:
item = translate[item]
if not seen.intersection(item._cloned_set):
@@ -2803,8 +2946,9 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
return froms
- def _get_display_froms(self, explicit_correlate_froms=None,
- implicit_correlate_froms=None):
+ def _get_display_froms(
+ self, explicit_correlate_froms=None, implicit_correlate_froms=None
+ ):
"""Return the full list of 'from' clauses to be displayed.
Takes into account a set of existing froms which may be
@@ -2815,17 +2959,17 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
"""
froms = self._froms
- toremove = set(itertools.chain(*[
- _expand_cloned(f._hide_froms)
- for f in froms]))
+ toremove = set(
+ itertools.chain(*[_expand_cloned(f._hide_froms) for f in froms])
+ )
if toremove:
# if we're maintaining clones of froms,
# add the copies out to the toremove list. only include
# clones that are lexical equivalents.
if self._from_cloned:
toremove.update(
- self._from_cloned[f] for f in
- toremove.intersection(self._from_cloned)
+ self._from_cloned[f]
+ for f in toremove.intersection(self._from_cloned)
if self._from_cloned[f]._is_lexical_equivalent(f)
)
# filter out to FROM clauses not in the list,
@@ -2836,41 +2980,53 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
to_correlate = self._correlate
if to_correlate:
froms = [
- f for f in froms if f not in
- _cloned_intersection(
+ f
+ for f in froms
+ if f
+ not in _cloned_intersection(
_cloned_intersection(
- froms, explicit_correlate_froms or ()),
- to_correlate
+ froms, explicit_correlate_froms or ()
+ ),
+ to_correlate,
)
]
if self._correlate_except is not None:
froms = [
- f for f in froms if f not in
- _cloned_difference(
+ f
+ for f in froms
+ if f
+ not in _cloned_difference(
_cloned_intersection(
- froms, explicit_correlate_froms or ()),
- self._correlate_except
+ froms, explicit_correlate_froms or ()
+ ),
+ self._correlate_except,
)
]
- if self._auto_correlate and \
- implicit_correlate_froms and \
- len(froms) > 1:
+ if (
+ self._auto_correlate
+ and implicit_correlate_froms
+ and len(froms) > 1
+ ):
froms = [
- f for f in froms if f not in
- _cloned_intersection(froms, implicit_correlate_froms)
+ f
+ for f in froms
+ if f
+ not in _cloned_intersection(froms, implicit_correlate_froms)
]
if not len(froms):
- raise exc.InvalidRequestError("Select statement '%s"
- "' returned no FROM clauses "
- "due to auto-correlation; "
- "specify correlate(<tables>) "
- "to control correlation "
- "manually." % self)
+ raise exc.InvalidRequestError(
+ "Select statement '%s"
+ "' returned no FROM clauses "
+ "due to auto-correlation; "
+ "specify correlate(<tables>) "
+ "to control correlation "
+ "manually." % self
+ )
return froms
@@ -2885,7 +3041,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
return self._get_display_froms()
- def with_statement_hint(self, text, dialect_name='*'):
+ def with_statement_hint(self, text, dialect_name="*"):
"""add a statement hint to this :class:`.Select`.
This method is similar to :meth:`.Select.with_hint` except that
@@ -2906,7 +3062,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
return self.with_hint(None, text, dialect_name)
@_generative
- def with_hint(self, selectable, text, dialect_name='*'):
+ def with_hint(self, selectable, text, dialect_name="*"):
r"""Add an indexing or other executional context hint for the given
selectable to this :class:`.Select`.
@@ -2940,17 +3096,18 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
"""
if selectable is None:
- self._statement_hints += ((dialect_name, text), )
+ self._statement_hints += ((dialect_name, text),)
else:
- self._hints = self._hints.union(
- {(selectable, dialect_name): text})
+ self._hints = self._hints.union({(selectable, dialect_name): text})
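
A short sketch of the two hint paths above; ``mytable`` is hypothetical, and
the index hint string is the Oracle example from the documentation::

    from sqlalchemy import select, table, column

    mytable = table("mytable", column("id"))

    # per-selectable hint, keyed by (selectable, dialect_name)
    stmt = select([mytable.c.id]).with_hint(
        mytable, "index(%(name)s ix_mytable)", dialect_name="oracle"
    )

    # statement-wide hint; dialect_name defaults to "*" (all dialects),
    # the hint text here is just a placeholder
    stmt = stmt.with_statement_hint("some hint")
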
@property
def type(self):
- raise exc.InvalidRequestError("Select objects don't have a type. "
- "Call as_scalar() on this Select "
- "object to return a 'scalar' version "
- "of this Select.")
+ raise exc.InvalidRequestError(
+ "Select objects don't have a type. "
+ "Call as_scalar() on this Select "
+ "object to return a 'scalar' version "
+ "of this Select."
+ )
@_memoized_property.method
def locate_all_froms(self):
@@ -2977,10 +3134,13 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
with_cols = dict(
(c._resolve_label or c._label or c.key, c)
for c in _select_iterables(self._raw_columns)
- if c._allow_label_resolve)
+ if c._allow_label_resolve
+ )
only_froms = dict(
- (c.key, c) for c in
- _select_iterables(self.froms) if c._allow_label_resolve)
+ (c.key, c)
+ for c in _select_iterables(self.froms)
+ if c._allow_label_resolve
+ )
only_cols = with_cols.copy()
for key, value in only_froms.items():
with_cols.setdefault(key, value)
@@ -3011,11 +3171,13 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
# gets cleared on each generation. previously we were "baking"
# _froms into self._from_obj.
self._from_cloned = from_cloned = dict(
- (f, clone(f, **kw)) for f in self._from_obj.union(self._froms))
+ (f, clone(f, **kw)) for f in self._from_obj.union(self._froms)
+ )
# 3. update persistent _from_obj with the cloned versions.
- self._from_obj = util.OrderedSet(from_cloned[f] for f in
- self._from_obj)
+ self._from_obj = util.OrderedSet(
+ from_cloned[f] for f in self._from_obj
+ )
# the _correlate collection is done separately, what can happen
# here is the same item is _correlate as in _from_obj but the
@@ -3023,16 +3185,22 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
# RelationshipProperty.Comparator._criterion_exists() does
# this). Also keep _correlate liberally open with its previous
# contents, as this set is used for matching, not rendering.
- self._correlate = set(clone(f) for f in
- self._correlate).union(self._correlate)
+ self._correlate = set(clone(f) for f in self._correlate).union(
+ self._correlate
+ )
# 4. clone other things. The difficulty here is that Column
# objects are not actually cloned, and refer to their original
# .table, resulting in the wrong "from" parent after a clone
# operation. Hence _from_cloned and _from_obj supersede what is
# present here.
self._raw_columns = [clone(c, **kw) for c in self._raw_columns]
- for attr in '_whereclause', '_having', '_order_by_clause', \
- '_group_by_clause', '_for_update_arg':
+ for attr in (
+ "_whereclause",
+ "_having",
+ "_order_by_clause",
+ "_group_by_clause",
+ "_for_update_arg",
+ ):
if getattr(self, attr) is not None:
setattr(self, attr, clone(getattr(self, attr), **kw))
@@ -3043,12 +3211,21 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
def get_children(self, column_collections=True, **kwargs):
"""return child elements as per the ClauseElement specification."""
- return (column_collections and list(self.columns) or []) + \
- self._raw_columns + list(self._froms) + \
- [x for x in
- (self._whereclause, self._having,
- self._order_by_clause, self._group_by_clause)
- if x is not None]
+ return (
+ (column_collections and list(self.columns) or [])
+ + self._raw_columns
+ + list(self._froms)
+ + [
+ x
+ for x in (
+ self._whereclause,
+ self._having,
+ self._order_by_clause,
+ self._group_by_clause,
+ )
+ if x is not None
+ ]
+ )
@_generative
def column(self, column):
@@ -3094,7 +3271,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
sqlutil.reduce_columns(
self.inner_columns,
only_synonyms=only_synonyms,
- *(self._whereclause, ) + tuple(self._from_obj)
+ *(self._whereclause,) + tuple(self._from_obj)
)
)
@@ -3307,7 +3484,8 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
self._correlate = ()
else:
self._correlate = set(self._correlate).union(
- _interpret_as_from(f) for f in fromclauses)
+ _interpret_as_from(f) for f in fromclauses
+ )
@_generative
def correlate_except(self, *fromclauses):
@@ -3349,7 +3527,8 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
self._correlate_except = ()
else:
self._correlate_except = set(self._correlate_except or ()).union(
- _interpret_as_from(f) for f in fromclauses)
+ _interpret_as_from(f) for f in fromclauses
+ )
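
A minimal correlated-subquery sketch using the correlation sets maintained
above; ``users`` and ``addresses`` are hypothetical tables::

    from sqlalchemy import select, table, column

    users = table("users", column("id"), column("name"))
    addresses = table("addresses", column("id"), column("user_id"))

    # correlate(users): users is omitted from the subquery's FROM list
    # when it renders inside an enclosing SELECT against users
    subq = (
        select([addresses.c.id])
        .where(addresses.c.user_id == users.c.id)
        .correlate(users)
        .as_scalar()
    )
    stmt = select([users.c.name, subq])
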
def append_correlation(self, fromclause):
"""append the given correlation expression to this select()
@@ -3363,7 +3542,8 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
self._auto_correlate = False
self._correlate = set(self._correlate).union(
- _interpret_as_from(f) for f in fromclause)
+ _interpret_as_from(f) for f in fromclause
+ )
def append_column(self, column):
"""append the given column expression to the columns clause of this
@@ -3415,8 +3595,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
"""
self._reset_exported()
- self._whereclause = and_(
- True_._ifnone(self._whereclause), whereclause)
+ self._whereclause = and_(True_._ifnone(self._whereclause), whereclause)
def append_having(self, having):
"""append the given expression to this select() construct's HAVING
@@ -3463,19 +3642,17 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
return [
name_for_col(c)
- for c in util.unique_list(
- _select_iterables(self._raw_columns))
+ for c in util.unique_list(_select_iterables(self._raw_columns))
]
else:
return [
(None, c)
- for c in util.unique_list(
- _select_iterables(self._raw_columns))
+ for c in util.unique_list(_select_iterables(self._raw_columns))
]
def _populate_column_collection(self):
for name, c in self._columns_plus_names:
- if not hasattr(c, '_make_proxy'):
+ if not hasattr(c, "_make_proxy"):
continue
if name is None:
key = None
@@ -3486,9 +3663,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
else:
key = None
- c._make_proxy(self, key=key,
- name=name,
- name_is_truncatable=True)
+ c._make_proxy(self, key=key, name=name, name_is_truncatable=True)
def _refresh_for_new_column(self, column):
for fromclause in self._froms:
@@ -3501,15 +3676,16 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
self,
name=col._label if self.use_labels else None,
key=col._key_label if self.use_labels else None,
- name_is_truncatable=True)
+ name_is_truncatable=True,
+ )
return None
return None
def _needs_parens_for_grouping(self):
return (
- self._limit_clause is not None or
- self._offset_clause is not None or
- bool(self._order_by_clause.clauses)
+ self._limit_clause is not None
+ or self._offset_clause is not None
+ or bool(self._order_by_clause.clauses)
)
def self_group(self, against=None):
@@ -3521,8 +3697,10 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
expressions and should not require explicit use.
"""
- if isinstance(against, CompoundSelect) and \
- not self._needs_parens_for_grouping():
+ if (
+ isinstance(against, CompoundSelect)
+ and not self._needs_parens_for_grouping()
+ ):
return self
return FromGrouping(self)
@@ -3586,6 +3764,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
def _set_bind(self, bind):
self._bind = bind
+
bind = property(bind, _set_bind)
@@ -3600,9 +3779,12 @@ class ScalarSelect(Generative, Grouping):
@property
def columns(self):
- raise exc.InvalidRequestError('Scalar Select expression has no '
- 'columns; use this object directly '
- 'within a column-level expression.')
+ raise exc.InvalidRequestError(
+ "Scalar Select expression has no "
+ "columns; use this object directly "
+ "within a column-level expression."
+ )
+
c = columns
@_generative
@@ -3621,6 +3803,7 @@ class Exists(UnaryExpression):
"""Represent an ``EXISTS`` clause.
"""
+
__visit_name__ = UnaryExpression.__visit_name__
_from_objects = []
@@ -3646,12 +3829,16 @@ class Exists(UnaryExpression):
s = args[0]
else:
if not args:
- args = ([literal_column('*')],)
+ args = ([literal_column("*")],)
s = Select(*args, **kwargs).as_scalar().self_group()
- UnaryExpression.__init__(self, s, operator=operators.exists,
- type_=type_api.BOOLEANTYPE,
- wraps_column_expression=True)
+ UnaryExpression.__init__(
+ self,
+ s,
+ operator=operators.exists,
+ type_=type_api.BOOLEANTYPE,
+ wraps_column_expression=True,
+ )
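
The wrapped scalar select constructed above is what ``exists()`` renders; a
sketch with hypothetical ``users`` / ``addresses`` tables::

    from sqlalchemy import exists, select, table, column

    users = table("users", column("id"))
    addresses = table("addresses", column("user_id"))

    stmt = select([users.c.id]).where(
        exists().where(addresses.c.user_id == users.c.id)
    )
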
def select(self, whereclause=None, **params):
return Select([self], whereclause, **params)
@@ -3706,6 +3893,7 @@ class TextAsFrom(SelectBase):
:meth:`.TextClause.columns`
"""
+
__visit_name__ = "text_as_from"
_textual = True
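
``TextAsFrom`` is produced by ``TextClause.columns()``; a short sketch
(the SQL string and table are hypothetical)::

    from sqlalchemy import text, column, select

    t = text("SELECT id, name FROM users").columns(
        column("id"), column("name")
    )

    # the TextAsFrom exposes .c and can be used like any FROM clause
    stmt = select([t.c.id])
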
diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py
index c5708940b..61fc6d3c9 100644
--- a/lib/sqlalchemy/sql/sqltypes.py
+++ b/lib/sqlalchemy/sql/sqltypes.py
@@ -15,10 +15,21 @@ import collections
import json
from . import elements
-from .type_api import TypeEngine, TypeDecorator, to_instance, Variant, \
- Emulated, NativeForEmulated
-from .elements import quoted_name, TypeCoerce as type_coerce, _defer_name, \
- Slice, _literal_as_binds
+from .type_api import (
+ TypeEngine,
+ TypeDecorator,
+ to_instance,
+ Variant,
+ Emulated,
+ NativeForEmulated,
+)
+from .elements import (
+ quoted_name,
+ TypeCoerce as type_coerce,
+ _defer_name,
+ Slice,
+ _literal_as_binds,
+)
from .. import exc, util, processors
from .base import _bind_or_error, SchemaEventTarget
from . import operators
@@ -51,14 +62,15 @@ class _LookupExpressionAdapter(object):
def _adapt_expression(self, op, other_comparator):
othertype = other_comparator.type._type_affinity
lookup = self.type._expression_adaptations.get(
- op, self._blank_dict).get(
- othertype, self.type)
+ op, self._blank_dict
+ ).get(othertype, self.type)
if lookup is othertype:
return (op, other_comparator.type)
elif lookup is self.type._type_affinity:
return (op, self.type)
else:
return (op, to_instance(lookup))
+
comparator_factory = Comparator
@@ -68,17 +80,16 @@ class Concatenable(object):
typically strings."""
class Comparator(TypeEngine.Comparator):
-
def _adapt_expression(self, op, other_comparator):
- if (op is operators.add and
- isinstance(
- other_comparator,
- (Concatenable.Comparator, NullType.Comparator)
- )):
+ if op is operators.add and isinstance(
+ other_comparator,
+ (Concatenable.Comparator, NullType.Comparator),
+ ):
return operators.concat_op, self.expr.type
else:
return super(Concatenable.Comparator, self)._adapt_expression(
- op, other_comparator)
+ op, other_comparator
+ )
comparator_factory = Comparator
@@ -94,17 +105,15 @@ class Indexable(object):
"""
class Comparator(TypeEngine.Comparator):
-
def _setup_getitem(self, index):
raise NotImplementedError()
def __getitem__(self, index):
- adjusted_op, adjusted_right_expr, result_type = \
- self._setup_getitem(index)
+ adjusted_op, adjusted_right_expr, result_type = self._setup_getitem(
+ index
+ )
return self.operate(
- adjusted_op,
- adjusted_right_expr,
- result_type=result_type
+ adjusted_op, adjusted_right_expr, result_type=result_type
)
comparator_factory = Comparator
@@ -124,13 +133,16 @@ class String(Concatenable, TypeEngine):
"""
- __visit_name__ = 'string'
+ __visit_name__ = "string"
- def __init__(self, length=None, collation=None,
- convert_unicode=False,
- unicode_error=None,
- _warn_on_bytestring=False
- ):
+ def __init__(
+ self,
+ length=None,
+ collation=None,
+ convert_unicode=False,
+ unicode_error=None,
+ _warn_on_bytestring=False,
+ ):
"""
Create a string-holding type.
@@ -207,9 +219,10 @@ class String(Concatenable, TypeEngine):
strings from a column with varied or corrupted encodings.
"""
- if unicode_error is not None and convert_unicode != 'force':
- raise exc.ArgumentError("convert_unicode must be 'force' "
- "when unicode_error is set.")
+ if unicode_error is not None and convert_unicode != "force":
+ raise exc.ArgumentError(
+ "convert_unicode must be 'force' " "when unicode_error is set."
+ )
self.length = length
self.collation = collation
@@ -222,23 +235,29 @@ class String(Concatenable, TypeEngine):
value = value.replace("'", "''")
if dialect.identifier_preparer._double_percents:
- value = value.replace('%', '%%')
+ value = value.replace("%", "%%")
return "'%s'" % value
+
return process
def bind_processor(self, dialect):
if self.convert_unicode or dialect.convert_unicode:
- if dialect.supports_unicode_binds and \
- self.convert_unicode != 'force':
+ if (
+ dialect.supports_unicode_binds
+ and self.convert_unicode != "force"
+ ):
if self._warn_on_bytestring:
+
def process(value):
if isinstance(value, util.binary_type):
util.warn_limited(
"Unicode type received non-unicode "
"bind param value %r.",
- (util.ellipses_string(value),))
+ (util.ellipses_string(value),),
+ )
return value
+
return process
else:
return None
@@ -253,29 +272,34 @@ class String(Concatenable, TypeEngine):
util.warn_limited(
"Unicode type received non-unicode bind "
"param value %r.",
- (util.ellipses_string(value),))
+ (util.ellipses_string(value),),
+ )
return value
+
return process
else:
return None
def result_processor(self, dialect, coltype):
wants_unicode = self.convert_unicode or dialect.convert_unicode
- needs_convert = wants_unicode and \
- (dialect.returns_unicode_strings is not True or
- self.convert_unicode in ('force', 'force_nocheck'))
+ needs_convert = wants_unicode and (
+ dialect.returns_unicode_strings is not True
+ or self.convert_unicode in ("force", "force_nocheck")
+ )
needs_isinstance = (
- needs_convert and
- dialect.returns_unicode_strings and
- self.convert_unicode != 'force_nocheck'
+ needs_convert
+ and dialect.returns_unicode_strings
+ and self.convert_unicode != "force_nocheck"
)
if needs_convert:
if needs_isinstance:
return processors.to_conditional_unicode_processor_factory(
- dialect.encoding, self.unicode_error)
+ dialect.encoding, self.unicode_error
+ )
else:
return processors.to_unicode_processor_factory(
- dialect.encoding, self.unicode_error)
+ dialect.encoding, self.unicode_error
+ )
else:
return None
@@ -301,7 +325,8 @@ class Text(String):
argument here, it will be rejected by others.
"""
- __visit_name__ = 'text'
+
+ __visit_name__ = "text"
class Unicode(String):
@@ -360,7 +385,7 @@ class Unicode(String):
"""
- __visit_name__ = 'unicode'
+ __visit_name__ = "unicode"
def __init__(self, length=None, **kwargs):
"""
@@ -371,8 +396,8 @@ class Unicode(String):
defaults to ``True``.
"""
- kwargs.setdefault('convert_unicode', True)
- kwargs.setdefault('_warn_on_bytestring', True)
+ kwargs.setdefault("convert_unicode", True)
+ kwargs.setdefault("_warn_on_bytestring", True)
super(Unicode, self).__init__(length=length, **kwargs)
@@ -389,7 +414,7 @@ class UnicodeText(Text):
"""
- __visit_name__ = 'unicode_text'
+ __visit_name__ = "unicode_text"
def __init__(self, length=None, **kwargs):
"""
@@ -400,8 +425,8 @@ class UnicodeText(Text):
defaults to ``True``.
"""
- kwargs.setdefault('convert_unicode', True)
- kwargs.setdefault('_warn_on_bytestring', True)
+ kwargs.setdefault("convert_unicode", True)
+ kwargs.setdefault("_warn_on_bytestring", True)
super(UnicodeText, self).__init__(length=length, **kwargs)
@@ -409,7 +434,7 @@ class Integer(_LookupExpressionAdapter, TypeEngine):
"""A type for ``int`` integers."""
- __visit_name__ = 'integer'
+ __visit_name__ = "integer"
def get_dbapi_type(self, dbapi):
return dbapi.NUMBER
@@ -421,6 +446,7 @@ class Integer(_LookupExpressionAdapter, TypeEngine):
def literal_processor(self, dialect):
def process(value):
return str(value)
+
return process
@util.memoized_property
@@ -438,18 +464,9 @@ class Integer(_LookupExpressionAdapter, TypeEngine):
Integer: self.__class__,
Numeric: Numeric,
},
- operators.div: {
- Integer: self.__class__,
- Numeric: Numeric,
- },
- operators.truediv: {
- Integer: self.__class__,
- Numeric: Numeric,
- },
- operators.sub: {
- Integer: self.__class__,
- Numeric: Numeric,
- },
+ operators.div: {Integer: self.__class__, Numeric: Numeric},
+ operators.truediv: {Integer: self.__class__, Numeric: Numeric},
+ operators.sub: {Integer: self.__class__, Numeric: Numeric},
}
@@ -462,7 +479,7 @@ class SmallInteger(Integer):
"""
- __visit_name__ = 'small_integer'
+ __visit_name__ = "small_integer"
class BigInteger(Integer):
@@ -474,7 +491,7 @@ class BigInteger(Integer):
"""
- __visit_name__ = 'big_integer'
+ __visit_name__ = "big_integer"
class Numeric(_LookupExpressionAdapter, TypeEngine):
@@ -517,12 +534,17 @@ class Numeric(_LookupExpressionAdapter, TypeEngine):
"""
- __visit_name__ = 'numeric'
+ __visit_name__ = "numeric"
_default_decimal_return_scale = 10
- def __init__(self, precision=None, scale=None,
- decimal_return_scale=None, asdecimal=True):
+ def __init__(
+ self,
+ precision=None,
+ scale=None,
+ decimal_return_scale=None,
+ asdecimal=True,
+ ):
"""
Construct a Numeric.
@@ -587,6 +609,7 @@ class Numeric(_LookupExpressionAdapter, TypeEngine):
def literal_processor(self, dialect):
def process(value):
return str(value)
+
return process
@property
@@ -608,19 +631,23 @@ class Numeric(_LookupExpressionAdapter, TypeEngine):
# we're a "numeric", DBAPI will give us Decimal directly
return None
else:
- util.warn('Dialect %s+%s does *not* support Decimal '
- 'objects natively, and SQLAlchemy must '
- 'convert from floating point - rounding '
- 'errors and other issues may occur. Please '
- 'consider storing Decimal numbers as strings '
- 'or integers on this platform for lossless '
- 'storage.' % (dialect.name, dialect.driver))
+ util.warn(
+ "Dialect %s+%s does *not* support Decimal "
+ "objects natively, and SQLAlchemy must "
+ "convert from floating point - rounding "
+ "errors and other issues may occur. Please "
+ "consider storing Decimal numbers as strings "
+ "or integers on this platform for lossless "
+ "storage." % (dialect.name, dialect.driver)
+ )
# we're a "numeric", DBAPI returns floats, convert.
return processors.to_decimal_processor_factory(
decimal.Decimal,
- self.scale if self.scale is not None
- else self._default_decimal_return_scale)
+ self.scale
+ if self.scale is not None
+ else self._default_decimal_return_scale,
+ )
else:
if dialect.supports_native_decimal:
return processors.to_float
@@ -635,22 +662,13 @@ class Numeric(_LookupExpressionAdapter, TypeEngine):
Numeric: self.__class__,
Integer: self.__class__,
},
- operators.div: {
- Numeric: self.__class__,
- Integer: self.__class__,
- },
+ operators.div: {Numeric: self.__class__, Integer: self.__class__},
operators.truediv: {
Numeric: self.__class__,
Integer: self.__class__,
},
- operators.add: {
- Numeric: self.__class__,
- Integer: self.__class__,
- },
- operators.sub: {
- Numeric: self.__class__,
- Integer: self.__class__,
- }
+ operators.add: {Numeric: self.__class__, Integer: self.__class__},
+ operators.sub: {Numeric: self.__class__, Integer: self.__class__},
}
@@ -675,12 +693,17 @@ class Float(Numeric):
"""
- __visit_name__ = 'float'
+ __visit_name__ = "float"
scale = None
- def __init__(self, precision=None, asdecimal=False,
- decimal_return_scale=None, **kwargs):
+ def __init__(
+ self,
+ precision=None,
+ asdecimal=False,
+ decimal_return_scale=None,
+ **kwargs
+ ):
r"""
Construct a Float.
@@ -713,14 +736,15 @@ class Float(Numeric):
self.asdecimal = asdecimal
self.decimal_return_scale = decimal_return_scale
if kwargs:
- util.warn_deprecated("Additional keyword arguments "
- "passed to Float ignored.")
+ util.warn_deprecated(
+ "Additional keyword arguments " "passed to Float ignored."
+ )
def result_processor(self, dialect, coltype):
if self.asdecimal:
return processors.to_decimal_processor_factory(
- decimal.Decimal,
- self._effective_decimal_return_scale)
+ decimal.Decimal, self._effective_decimal_return_scale
+ )
elif dialect.supports_native_decimal:
return processors.to_float
else:
@@ -746,7 +770,7 @@ class DateTime(_LookupExpressionAdapter, TypeEngine):
"""
- __visit_name__ = 'datetime'
+ __visit_name__ = "datetime"
def __init__(self, timezone=False):
"""Construct a new :class:`.DateTime`.
@@ -777,13 +801,8 @@ class DateTime(_LookupExpressionAdapter, TypeEngine):
# static/functions-datetime.html.
return {
- operators.add: {
- Interval: self.__class__,
- },
- operators.sub: {
- Interval: self.__class__,
- DateTime: Interval,
- },
+ operators.add: {Interval: self.__class__},
+ operators.sub: {Interval: self.__class__, DateTime: Interval},
}
@@ -791,7 +810,7 @@ class Date(_LookupExpressionAdapter, TypeEngine):
"""A type for ``datetime.date()`` objects."""
- __visit_name__ = 'date'
+ __visit_name__ = "date"
def get_dbapi_type(self, dbapi):
return dbapi.DATETIME
@@ -814,12 +833,9 @@ class Date(_LookupExpressionAdapter, TypeEngine):
operators.sub: {
# date - integer = date
Integer: self.__class__,
-
# date - date = integer.
Date: Integer,
-
Interval: DateTime,
-
# date - datetime = interval,
# this one is not in the PG docs
# but works
@@ -832,7 +848,7 @@ class Time(_LookupExpressionAdapter, TypeEngine):
"""A type for ``datetime.time()`` objects."""
- __visit_name__ = 'time'
+ __visit_name__ = "time"
def __init__(self, timezone=False):
self.timezone = timezone
@@ -850,14 +866,8 @@ class Time(_LookupExpressionAdapter, TypeEngine):
# static/functions-datetime.html.
return {
- operators.add: {
- Date: DateTime,
- Interval: self.__class__
- },
- operators.sub: {
- Time: Interval,
- Interval: self.__class__,
- },
+ operators.add: {Date: DateTime, Interval: self.__class__},
+ operators.sub: {Time: Interval, Interval: self.__class__},
}
@@ -872,6 +882,7 @@ class _Binary(TypeEngine):
def process(value):
value = value.decode(dialect.encoding).replace("'", "''")
return "'%s'" % value
+
return process
@property
@@ -891,14 +902,17 @@ class _Binary(TypeEngine):
return DBAPIBinary(value)
else:
return None
+
return process
# Python 3 has native bytes() type
# both sqlite3 and pg8000 seem to return it,
# psycopg2 as of 2.5 returns 'memoryview'
if util.py2k:
+
def result_processor(self, dialect, coltype):
if util.jython:
+
def process(value):
if value is not None:
if isinstance(value, array.array):
@@ -906,15 +920,19 @@ class _Binary(TypeEngine):
return str(value)
else:
return None
+
else:
process = processors.to_str
return process
+
else:
+
def result_processor(self, dialect, coltype):
def process(value):
if value is not None:
value = bytes(value)
return value
+
return process
def coerce_compared_value(self, op, value):
@@ -939,7 +957,7 @@ class LargeBinary(_Binary):
"""
- __visit_name__ = 'large_binary'
+ __visit_name__ = "large_binary"
def __init__(self, length=None):
"""
@@ -958,8 +976,9 @@ class Binary(LargeBinary):
"""Deprecated. Renamed to LargeBinary."""
def __init__(self, *arg, **kw):
- util.warn_deprecated('The Binary type has been renamed to '
- 'LargeBinary.')
+ util.warn_deprecated(
+ "The Binary type has been renamed to " "LargeBinary."
+ )
LargeBinary.__init__(self, *arg, **kw)
@@ -986,8 +1005,15 @@ class SchemaType(SchemaEventTarget):
"""
- def __init__(self, name=None, schema=None, metadata=None,
- inherit_schema=False, quote=None, _create_events=True):
+ def __init__(
+ self,
+ name=None,
+ schema=None,
+ metadata=None,
+ inherit_schema=False,
+ quote=None,
+ _create_events=True,
+ ):
if name is not None:
self.name = quoted_name(name, quote)
else:
@@ -1001,12 +1027,12 @@ class SchemaType(SchemaEventTarget):
event.listen(
self.metadata,
"before_create",
- util.portable_instancemethod(self._on_metadata_create)
+ util.portable_instancemethod(self._on_metadata_create),
)
event.listen(
self.metadata,
"after_drop",
- util.portable_instancemethod(self._on_metadata_drop)
+ util.portable_instancemethod(self._on_metadata_drop),
)
def _translate_schema(self, effective_schema, map_):
@@ -1018,7 +1044,7 @@ class SchemaType(SchemaEventTarget):
def _variant_mapping_for_set_table(self, column):
if isinstance(column.type, Variant):
variant_mapping = column.type.mapping.copy()
- variant_mapping['_default'] = column.type.impl
+ variant_mapping["_default"] = column.type.impl
else:
variant_mapping = None
return variant_mapping
@@ -1036,15 +1062,15 @@ class SchemaType(SchemaEventTarget):
table,
"before_create",
util.portable_instancemethod(
- self._on_table_create,
- {"variant_mapping": variant_mapping})
+ self._on_table_create, {"variant_mapping": variant_mapping}
+ ),
)
event.listen(
table,
"after_drop",
util.portable_instancemethod(
- self._on_table_drop,
- {"variant_mapping": variant_mapping})
+ self._on_table_drop, {"variant_mapping": variant_mapping}
+ ),
)
if self.metadata is None:
# TODO: what's the difference between self.metadata
@@ -1054,29 +1080,33 @@ class SchemaType(SchemaEventTarget):
"before_create",
util.portable_instancemethod(
self._on_metadata_create,
- {"variant_mapping": variant_mapping})
+ {"variant_mapping": variant_mapping},
+ ),
)
event.listen(
table.metadata,
"after_drop",
util.portable_instancemethod(
self._on_metadata_drop,
- {"variant_mapping": variant_mapping})
+ {"variant_mapping": variant_mapping},
+ ),
)
def copy(self, **kw):
return self.adapt(self.__class__, _create_events=True)
def adapt(self, impltype, **kw):
- schema = kw.pop('schema', self.schema)
- metadata = kw.pop('metadata', self.metadata)
- _create_events = kw.pop('_create_events', False)
- return impltype(name=self.name,
- schema=schema,
- inherit_schema=self.inherit_schema,
- metadata=metadata,
- _create_events=_create_events,
- **kw)
+ schema = kw.pop("schema", self.schema)
+ metadata = kw.pop("metadata", self.metadata)
+ _create_events = kw.pop("_create_events", False)
+ return impltype(
+ name=self.name,
+ schema=schema,
+ inherit_schema=self.inherit_schema,
+ metadata=metadata,
+ _create_events=_create_events,
+ **kw
+ )
@property
def bind(self):
@@ -1133,15 +1163,17 @@ class SchemaType(SchemaEventTarget):
t._on_metadata_drop(target, bind, **kw)
def _is_impl_for_variant(self, dialect, kw):
- variant_mapping = kw.pop('variant_mapping', None)
+ variant_mapping = kw.pop("variant_mapping", None)
if variant_mapping is None:
return True
- if dialect.name in variant_mapping and \
- variant_mapping[dialect.name] is self:
+ if (
+ dialect.name in variant_mapping
+ and variant_mapping[dialect.name] is self
+ ):
return True
elif dialect.name not in variant_mapping:
- return variant_mapping['_default'] is self
+ return variant_mapping["_default"] is self
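
The ``variant_mapping`` consulted above is built from
``TypeEngine.with_variant()``; a sketch::

    from sqlalchemy import Numeric
    from sqlalchemy.dialects import mysql

    # roughly: {"_default": <Numeric impl>, "mysql": DECIMAL(10, 2)}
    t = Numeric(10, 2).with_variant(mysql.DECIMAL(10, 2), "mysql")
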
class Enum(Emulated, String, SchemaType):
@@ -1220,7 +1252,8 @@ class Enum(Emulated, String, SchemaType):
:class:`.mysql.ENUM` - MySQL-specific type
"""
- __visit_name__ = 'enum'
+
+ __visit_name__ = "enum"
def __init__(self, *enums, **kw):
r"""Construct an enum.
@@ -1322,15 +1355,15 @@ class Enum(Emulated, String, SchemaType):
other arguments in kw to pass through.
"""
- self.native_enum = kw.pop('native_enum', True)
- self.create_constraint = kw.pop('create_constraint', True)
- self.values_callable = kw.pop('values_callable', None)
+ self.native_enum = kw.pop("native_enum", True)
+ self.create_constraint = kw.pop("create_constraint", True)
+ self.values_callable = kw.pop("values_callable", None)
values, objects = self._parse_into_values(enums, kw)
self._setup_for_values(values, objects, kw)
- convert_unicode = kw.pop('convert_unicode', None)
- self.validate_strings = kw.pop('validate_strings', False)
+ convert_unicode = kw.pop("convert_unicode", None)
+ self.validate_strings = kw.pop("validate_strings", False)
if convert_unicode is None:
for e in self.enums:
@@ -1347,33 +1380,35 @@ class Enum(Emulated, String, SchemaType):
self._valid_lookup[None] = self._object_lookup[None] = None
super(Enum, self).__init__(
- length=length,
- convert_unicode=convert_unicode,
+ length=length, convert_unicode=convert_unicode
)
if self.enum_class:
- kw.setdefault('name', self.enum_class.__name__.lower())
+ kw.setdefault("name", self.enum_class.__name__.lower())
SchemaType.__init__(
self,
- name=kw.pop('name', None),
- schema=kw.pop('schema', None),
- metadata=kw.pop('metadata', None),
- inherit_schema=kw.pop('inherit_schema', False),
- quote=kw.pop('quote', None),
- _create_events=kw.pop('_create_events', True)
+ name=kw.pop("name", None),
+ schema=kw.pop("schema", None),
+ metadata=kw.pop("metadata", None),
+ inherit_schema=kw.pop("inherit_schema", False),
+ quote=kw.pop("quote", None),
+ _create_events=kw.pop("_create_events", True),
)
def _parse_into_values(self, enums, kw):
- if not enums and '_enums' in kw:
- enums = kw.pop('_enums')
+ if not enums and "_enums" in kw:
+ enums = kw.pop("_enums")
- if len(enums) == 1 and hasattr(enums[0], '__members__'):
+ if len(enums) == 1 and hasattr(enums[0], "__members__"):
self.enum_class = enums[0]
if self.values_callable:
values = self.values_callable(self.enum_class)
else:
values = list(self.enum_class.__members__)
- objects = [self.enum_class.__members__[k] for k in self.enum_class.__members__]
+ objects = [
+ self.enum_class.__members__[k]
+ for k in self.enum_class.__members__
+ ]
return values, objects
else:
self.enum_class = None
@@ -1382,18 +1417,16 @@ class Enum(Emulated, String, SchemaType):
def _setup_for_values(self, values, objects, kw):
self.enums = list(values)
- self._valid_lookup = dict(
- zip(reversed(objects), reversed(values))
- )
+ self._valid_lookup = dict(zip(reversed(objects), reversed(values)))
- self._object_lookup = dict(
- zip(values, objects)
- )
+ self._object_lookup = dict(zip(values, objects))
- self._valid_lookup.update([
- (value, self._valid_lookup[self._object_lookup[value]])
- for value in values
- ])
+ self._valid_lookup.update(
+ [
+ (value, self._valid_lookup[self._object_lookup[value]])
+ for value in values
+ ]
+ )
@property
def native(self):
@@ -1411,22 +1444,24 @@ class Enum(Emulated, String, SchemaType):
# here between an INSERT statement and a criteria used in a SELECT,
# for now we're staying conservative w/ behavioral changes (perhaps
# someone has a trigger that handles strings on INSERT)
- if not self.validate_strings and \
- isinstance(elem, compat.string_types):
+ if not self.validate_strings and isinstance(
+ elem, compat.string_types
+ ):
return elem
else:
raise LookupError(
- '"%s" is not among the defined enum values' % elem)
+ '"%s" is not among the defined enum values' % elem
+ )
class Comparator(String.Comparator):
-
def _adapt_expression(self, op, other_comparator):
op, typ = super(Enum.Comparator, self)._adapt_expression(
- op, other_comparator)
+ op, other_comparator
+ )
if op is operators.concat_op:
typ = String(
- self.type.length,
- convert_unicode=self.type.convert_unicode)
+ self.type.length, convert_unicode=self.type.convert_unicode
+ )
return op, typ
comparator_factory = Comparator
@@ -1436,38 +1471,40 @@ class Enum(Emulated, String, SchemaType):
return self._object_lookup[elem]
except KeyError:
raise LookupError(
- '"%s" is not among the defined enum values' % elem)
+ '"%s" is not among the defined enum values' % elem
+ )
def __repr__(self):
return util.generic_repr(
self,
- additional_kw=[('native_enum', True)],
+ additional_kw=[("native_enum", True)],
to_inspect=[Enum, SchemaType],
)
def adapt_to_emulated(self, impltype, **kw):
kw.setdefault("convert_unicode", self.convert_unicode)
kw.setdefault("validate_strings", self.validate_strings)
- kw.setdefault('name', self.name)
- kw.setdefault('schema', self.schema)
- kw.setdefault('inherit_schema', self.inherit_schema)
- kw.setdefault('metadata', self.metadata)
- kw.setdefault('_create_events', False)
- kw.setdefault('native_enum', self.native_enum)
- kw.setdefault('values_callable', self.values_callable)
- kw.setdefault('create_constraint', self.create_constraint)
- assert '_enums' in kw
+ kw.setdefault("name", self.name)
+ kw.setdefault("schema", self.schema)
+ kw.setdefault("inherit_schema", self.inherit_schema)
+ kw.setdefault("metadata", self.metadata)
+ kw.setdefault("_create_events", False)
+ kw.setdefault("native_enum", self.native_enum)
+ kw.setdefault("values_callable", self.values_callable)
+ kw.setdefault("create_constraint", self.create_constraint)
+ assert "_enums" in kw
return impltype(**kw)
def adapt(self, impltype, **kw):
- kw['_enums'] = self._enums_argument
+ kw["_enums"] = self._enums_argument
return super(Enum, self).adapt(impltype, **kw)
def _should_create_constraint(self, compiler, **kw):
if not self._is_impl_for_variant(compiler.dialect, kw):
return False
- return not self.native_enum or \
- not compiler.dialect.supports_native_enum
+ return (
+ not self.native_enum or not compiler.dialect.supports_native_enum
+ )
@util.dependencies("sqlalchemy.sql.schema")
def _set_table(self, schema, column, table):
@@ -1483,20 +1520,21 @@ class Enum(Emulated, String, SchemaType):
name=_defer_name(self.name),
_create_rule=util.portable_instancemethod(
self._should_create_constraint,
- {"variant_mapping": variant_mapping}),
- _type_bound=True
+ {"variant_mapping": variant_mapping},
+ ),
+ _type_bound=True,
)
assert e.table is table
def literal_processor(self, dialect):
- parent_processor = super(
- Enum, self).literal_processor(dialect)
+ parent_processor = super(Enum, self).literal_processor(dialect)
def process(value):
value = self._db_value_for_elem(value)
if parent_processor:
value = parent_processor(value)
return value
+
return process
def bind_processor(self, dialect):
@@ -1510,8 +1548,7 @@ class Enum(Emulated, String, SchemaType):
return process
def result_processor(self, dialect, coltype):
- parent_processor = super(Enum, self).result_processor(
- dialect, coltype)
+ parent_processor = super(Enum, self).result_processor(dialect, coltype)
def process(value):
if parent_processor:
@@ -1548,8 +1585,9 @@ class PickleType(TypeDecorator):
impl = LargeBinary
- def __init__(self, protocol=pickle.HIGHEST_PROTOCOL,
- pickler=None, comparator=None):
+ def __init__(
+ self, protocol=pickle.HIGHEST_PROTOCOL, pickler=None, comparator=None
+ ):
"""
Construct a PickleType.
@@ -1570,40 +1608,46 @@ class PickleType(TypeDecorator):
super(PickleType, self).__init__()
def __reduce__(self):
- return PickleType, (self.protocol,
- None,
- self.comparator)
+ return PickleType, (self.protocol, None, self.comparator)
def bind_processor(self, dialect):
impl_processor = self.impl.bind_processor(dialect)
dumps = self.pickler.dumps
protocol = self.protocol
if impl_processor:
+
def process(value):
if value is not None:
value = dumps(value, protocol)
return impl_processor(value)
+
else:
+
def process(value):
if value is not None:
value = dumps(value, protocol)
return value
+
return process
def result_processor(self, dialect, coltype):
impl_processor = self.impl.result_processor(dialect, coltype)
loads = self.pickler.loads
if impl_processor:
+
def process(value):
value = impl_processor(value)
if value is None:
return None
return loads(value)
+
else:
+
def process(value):
if value is None:
return None
return loads(value)
+
return process
def compare_values(self, x, y):
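# Sketch (editorial): the PickleType hunks above reformat SQLAlchemy's
# processor-closure pattern -- a bind/result processor is built once per
# dialect, hoisting hot-path lookups out of the per-value closure.
# make_bind_processor and its arguments are illustrative, not SQLAlchemy API.
import pickle


def make_bind_processor(impl_processor=None, protocol=pickle.HIGHEST_PROTOCOL):
    dumps = pickle.dumps  # resolve once, outside the per-value closure
    if impl_processor:

        def process(value):
            if value is not None:
                value = dumps(value, protocol)
            return impl_processor(value)

    else:

        def process(value):
            if value is not None:
                value = dumps(value, protocol)
            return value

    return process


assert make_bind_processor()(None) is None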
@@ -1635,11 +1679,10 @@ class Boolean(Emulated, TypeEngine, SchemaType):
"""
- __visit_name__ = 'boolean'
+ __visit_name__ = "boolean"
native = True
- def __init__(
- self, create_constraint=True, name=None, _create_events=True):
+ def __init__(self, create_constraint=True, name=None, _create_events=True):
"""Construct a Boolean.
:param create_constraint: defaults to True. If the boolean
@@ -1657,8 +1700,10 @@ class Boolean(Emulated, TypeEngine, SchemaType):
def _should_create_constraint(self, compiler, **kw):
if not self._is_impl_for_variant(compiler.dialect, kw):
return False
- return not compiler.dialect.supports_native_boolean and \
- compiler.dialect.non_native_boolean_check_constraint
+ return (
+ not compiler.dialect.supports_native_boolean
+ and compiler.dialect.non_native_boolean_check_constraint
+ )
@util.dependencies("sqlalchemy.sql.schema")
def _set_table(self, schema, column, table):
@@ -1672,8 +1717,9 @@ class Boolean(Emulated, TypeEngine, SchemaType):
name=_defer_name(self.name),
_create_rule=util.portable_instancemethod(
self._should_create_constraint,
- {"variant_mapping": variant_mapping}),
- _type_bound=True
+ {"variant_mapping": variant_mapping},
+ ),
+ _type_bound=True,
)
assert e.table is table
@@ -1686,11 +1732,11 @@ class Boolean(Emulated, TypeEngine, SchemaType):
def _strict_as_bool(self, value):
if value not in self._strict_bools:
if not isinstance(value, int):
- raise TypeError(
- "Not a boolean value: %r" % value)
+ raise TypeError("Not a boolean value: %r" % value)
else:
raise ValueError(
- "Value %r is not None, True, or False" % value)
+ "Value %r is not None, True, or False" % value
+ )
return value
def literal_processor(self, dialect):
@@ -1700,6 +1746,7 @@ class Boolean(Emulated, TypeEngine, SchemaType):
def process(value):
return true if self._strict_as_bool(value) else false
+
return process
def bind_processor(self, dialect):
@@ -1714,6 +1761,7 @@ class Boolean(Emulated, TypeEngine, SchemaType):
if value is not None:
value = _coerce(value)
return value
+
return process
def result_processor(self, dialect, coltype):
@@ -1736,18 +1784,10 @@ class _AbstractInterval(_LookupExpressionAdapter, TypeEngine):
DateTime: DateTime,
Time: Time,
},
- operators.sub: {
- Interval: self.__class__
- },
- operators.mul: {
- Numeric: self.__class__
- },
- operators.truediv: {
- Numeric: self.__class__
- },
- operators.div: {
- Numeric: self.__class__
- }
+ operators.sub: {Interval: self.__class__},
+ operators.mul: {Numeric: self.__class__},
+ operators.truediv: {Numeric: self.__class__},
+ operators.div: {Numeric: self.__class__},
}
@property
@@ -1780,9 +1820,7 @@ class Interval(Emulated, _AbstractInterval, TypeDecorator):
impl = DateTime
epoch = dt.datetime.utcfromtimestamp(0)
- def __init__(self, native=True,
- second_precision=None,
- day_precision=None):
+ def __init__(self, native=True, second_precision=None, day_precision=None):
"""Construct an Interval object.
:param native: when True, use the actual
@@ -1815,31 +1853,39 @@ class Interval(Emulated, _AbstractInterval, TypeDecorator):
impl_processor = self.impl.bind_processor(dialect)
epoch = self.epoch
if impl_processor:
+
def process(value):
if value is not None:
value = epoch + value
return impl_processor(value)
+
else:
+
def process(value):
if value is not None:
value = epoch + value
return value
+
return process
def result_processor(self, dialect, coltype):
impl_processor = self.impl.result_processor(dialect, coltype)
epoch = self.epoch
if impl_processor:
+
def process(value):
value = impl_processor(value)
if value is None:
return None
return value - epoch
+
else:
+
def process(value):
if value is None:
return None
return value - epoch
+
return process
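# Sketch (editorial, stdlib only): a quick worked check of the epoch
# arithmetic the Interval hunks above reformat -- on backends without a
# native interval type, a timedelta is stored as epoch + value (a DATETIME)
# and recovered as value - epoch.
import datetime as dt

epoch = dt.datetime.utcfromtimestamp(0)
value = dt.timedelta(days=2, hours=3)
stored = epoch + value           # what bind_processor sends to the database
assert stored - epoch == value   # what result_processor hands back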
@@ -1986,10 +2032,11 @@ class JSON(Indexable, TypeEngine):
"""
- __visit_name__ = 'JSON'
+
+ __visit_name__ = "JSON"
hashable = False
- NULL = util.symbol('JSON_NULL')
+ NULL = util.symbol("JSON_NULL")
"""Describe the json value of NULL.
This value is used to force the JSON value of ``"null"`` to be
@@ -2109,20 +2156,25 @@ class JSON(Indexable, TypeEngine):
class Comparator(Indexable.Comparator, Concatenable.Comparator):
"""Define comparison operations for :class:`.types.JSON`."""
- @util.dependencies('sqlalchemy.sql.default_comparator')
+ @util.dependencies("sqlalchemy.sql.default_comparator")
def _setup_getitem(self, default_comparator, index):
- if not isinstance(index, util.string_types) and \
- isinstance(index, compat.collections_abc.Sequence):
+ if not isinstance(index, util.string_types) and isinstance(
+ index, compat.collections_abc.Sequence
+ ):
index = default_comparator._check_literal(
- self.expr, operators.json_path_getitem_op,
- index, bindparam_type=JSON.JSONPathType
+ self.expr,
+ operators.json_path_getitem_op,
+ index,
+ bindparam_type=JSON.JSONPathType,
)
operator = operators.json_path_getitem_op
else:
index = default_comparator._check_literal(
- self.expr, operators.json_getitem_op,
- index, bindparam_type=JSON.JSONIndexType
+ self.expr,
+ operators.json_getitem_op,
+ index,
+ bindparam_type=JSON.JSONIndexType,
)
operator = operators.json_getitem_op
@@ -2172,6 +2224,7 @@ class JSON(Indexable, TypeEngine):
if string_process:
value = string_process(value)
return json_deserializer(value)
+
return process
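# Sketch (editorial): a hedged usage example for JSON.NULL, whose symbol the
# hunk above reformats -- it forces the JSON value "null" where plain None
# could be interpreted as SQL NULL. Table/column names are illustrative;
# assumes a SQLite build with JSON support.
from sqlalchemy import JSON, Column, Integer, MetaData, Table, create_engine

metadata = MetaData()
docs = Table(
    "docs",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("payload", JSON),
)
engine = create_engine("sqlite://")
metadata.create_all(engine)
with engine.connect() as conn:
    conn.execute(docs.insert(), {"payload": JSON.NULL})  # JSON 'null'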
@@ -2266,7 +2319,8 @@ class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine):
:class:`.postgresql.ARRAY`
"""
- __visit_name__ = 'ARRAY'
+
+ __visit_name__ = "ARRAY"
zero_indexes = False
"""if True, Python zero-based indexes should be interpreted as one-based
@@ -2285,21 +2339,23 @@ class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine):
if isinstance(index, slice):
return_type = self.type
if self.type.zero_indexes:
- index = slice(
- index.start + 1,
- index.stop + 1,
- index.step
- )
+ index = slice(index.start + 1, index.stop + 1, index.step)
index = Slice(
_literal_as_binds(
- index.start, name=self.expr.key,
- type_=type_api.INTEGERTYPE),
+ index.start,
+ name=self.expr.key,
+ type_=type_api.INTEGERTYPE,
+ ),
_literal_as_binds(
- index.stop, name=self.expr.key,
- type_=type_api.INTEGERTYPE),
+ index.stop,
+ name=self.expr.key,
+ type_=type_api.INTEGERTYPE,
+ ),
_literal_as_binds(
- index.step, name=self.expr.key,
- type_=type_api.INTEGERTYPE)
+ index.step,
+ name=self.expr.key,
+ type_=type_api.INTEGERTYPE,
+ ),
)
else:
if self.type.zero_indexes:
@@ -2307,16 +2363,18 @@ class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine):
if self.type.dimensions is None or self.type.dimensions == 1:
return_type = self.type.item_type
else:
- adapt_kw = {'dimensions': self.type.dimensions - 1}
+ adapt_kw = {"dimensions": self.type.dimensions - 1}
return_type = self.type.adapt(
- self.type.__class__, **adapt_kw)
+ self.type.__class__, **adapt_kw
+ )
return operators.getitem, index, return_type
def contains(self, *arg, **kw):
raise NotImplementedError(
"ARRAY.contains() not implemented for the base "
- "ARRAY type; please use the dialect-specific ARRAY type")
+ "ARRAY type; please use the dialect-specific ARRAY type"
+ )
@util.dependencies("sqlalchemy.sql.elements")
def any(self, elements, other, operator=None):
@@ -2350,7 +2408,7 @@ class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine):
operator = operator if operator else operators.eq
return operator(
elements._literal_as_binds(other),
- elements.CollectionAggregate._create_any(self.expr)
+ elements.CollectionAggregate._create_any(self.expr),
)
@util.dependencies("sqlalchemy.sql.elements")
@@ -2385,13 +2443,14 @@ class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine):
operator = operator if operator else operators.eq
return operator(
elements._literal_as_binds(other),
- elements.CollectionAggregate._create_all(self.expr)
+ elements.CollectionAggregate._create_all(self.expr),
)
comparator_factory = Comparator
- def __init__(self, item_type, as_tuple=False, dimensions=None,
- zero_indexes=False):
+ def __init__(
+ self, item_type, as_tuple=False, dimensions=None, zero_indexes=False
+ ):
"""Construct an :class:`.types.ARRAY`.
E.g.::
@@ -2424,8 +2483,10 @@ class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine):
"""
if isinstance(item_type, ARRAY):
- raise ValueError("Do not nest ARRAY types; ARRAY(basetype) "
- "handles multi-dimensional arrays of basetype")
+ raise ValueError(
+ "Do not nest ARRAY types; ARRAY(basetype) "
+ "handles multi-dimensional arrays of basetype"
+ )
if isinstance(item_type, type):
item_type = item_type()
self.item_type = item_type
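# Sketch (editorial): illustrating the guard just reformatted above --
# nesting ARRAY raises ValueError; multidimensional arrays are declared via
# the "dimensions" argument instead.
from sqlalchemy import ARRAY, Integer

matrix = ARRAY(Integer, dimensions=2)  # two-dimensional array of INTEGER
try:
    ARRAY(ARRAY(Integer))
except ValueError as err:
    print(err)  # "Do not nest ARRAY types; ..."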
@@ -2463,35 +2524,37 @@ class REAL(Float):
"""The SQL REAL type."""
- __visit_name__ = 'REAL'
+ __visit_name__ = "REAL"
class FLOAT(Float):
"""The SQL FLOAT type."""
- __visit_name__ = 'FLOAT'
+ __visit_name__ = "FLOAT"
class NUMERIC(Numeric):
"""The SQL NUMERIC type."""
- __visit_name__ = 'NUMERIC'
+ __visit_name__ = "NUMERIC"
class DECIMAL(Numeric):
"""The SQL DECIMAL type."""
- __visit_name__ = 'DECIMAL'
+ __visit_name__ = "DECIMAL"
class INTEGER(Integer):
"""The SQL INT or INTEGER type."""
- __visit_name__ = 'INTEGER'
+ __visit_name__ = "INTEGER"
+
+
INT = INTEGER
@@ -2499,14 +2562,14 @@ class SMALLINT(SmallInteger):
"""The SQL SMALLINT type."""
- __visit_name__ = 'SMALLINT'
+ __visit_name__ = "SMALLINT"
class BIGINT(BigInteger):
"""The SQL BIGINT type."""
- __visit_name__ = 'BIGINT'
+ __visit_name__ = "BIGINT"
class TIMESTAMP(DateTime):
@@ -2520,7 +2583,7 @@ class TIMESTAMP(DateTime):
"""
- __visit_name__ = 'TIMESTAMP'
+ __visit_name__ = "TIMESTAMP"
def __init__(self, timezone=False):
"""Construct a new :class:`.TIMESTAMP`.
@@ -2543,28 +2606,28 @@ class DATETIME(DateTime):
"""The SQL DATETIME type."""
- __visit_name__ = 'DATETIME'
+ __visit_name__ = "DATETIME"
class DATE(Date):
"""The SQL DATE type."""
- __visit_name__ = 'DATE'
+ __visit_name__ = "DATE"
class TIME(Time):
"""The SQL TIME type."""
- __visit_name__ = 'TIME'
+ __visit_name__ = "TIME"
class TEXT(Text):
"""The SQL TEXT type."""
- __visit_name__ = 'TEXT'
+ __visit_name__ = "TEXT"
class CLOB(Text):
@@ -2574,63 +2637,63 @@ class CLOB(Text):
This type is found in Oracle and Informix.
"""
- __visit_name__ = 'CLOB'
+ __visit_name__ = "CLOB"
class VARCHAR(String):
"""The SQL VARCHAR type."""
- __visit_name__ = 'VARCHAR'
+ __visit_name__ = "VARCHAR"
class NVARCHAR(Unicode):
"""The SQL NVARCHAR type."""
- __visit_name__ = 'NVARCHAR'
+ __visit_name__ = "NVARCHAR"
class CHAR(String):
"""The SQL CHAR type."""
- __visit_name__ = 'CHAR'
+ __visit_name__ = "CHAR"
class NCHAR(Unicode):
"""The SQL NCHAR type."""
- __visit_name__ = 'NCHAR'
+ __visit_name__ = "NCHAR"
class BLOB(LargeBinary):
"""The SQL BLOB type."""
- __visit_name__ = 'BLOB'
+ __visit_name__ = "BLOB"
class BINARY(_Binary):
"""The SQL BINARY type."""
- __visit_name__ = 'BINARY'
+ __visit_name__ = "BINARY"
class VARBINARY(_Binary):
"""The SQL VARBINARY type."""
- __visit_name__ = 'VARBINARY'
+ __visit_name__ = "VARBINARY"
class BOOLEAN(Boolean):
"""The SQL BOOLEAN type."""
- __visit_name__ = 'BOOLEAN'
+ __visit_name__ = "BOOLEAN"
class NullType(TypeEngine):
@@ -2657,7 +2720,8 @@ class NullType(TypeEngine):
construct.
"""
- __visit_name__ = 'null'
+
+ __visit_name__ = "null"
_isnull = True
@@ -2666,16 +2730,18 @@ class NullType(TypeEngine):
def literal_processor(self, dialect):
def process(value):
return "NULL"
+
return process
class Comparator(TypeEngine.Comparator):
-
def _adapt_expression(self, op, other_comparator):
- if isinstance(other_comparator, NullType.Comparator) or \
- not operators.is_commutative(op):
+ if isinstance(
+ other_comparator, NullType.Comparator
+ ) or not operators.is_commutative(op):
return op, self.expr.type
else:
return other_comparator._adapt_expression(op, self)
+
comparator_factory = Comparator
@@ -2694,6 +2760,7 @@ class MatchType(Boolean):
"""
+
NULLTYPE = NullType()
BOOLEANTYPE = Boolean()
STRINGTYPE = String()
@@ -2709,7 +2776,7 @@ _type_map = {
dt.datetime: DateTime(),
dt.time: Time(),
dt.timedelta: Interval(),
- util.NoneType: NULLTYPE
+ util.NoneType: NULLTYPE,
}
if util.py3k:
@@ -2729,19 +2796,23 @@ def _resolve_value_to_type(value):
# objects.
insp = inspection.inspect(value, False)
if (
- insp is not None and
- # foil mock.Mock() and other impostors by ensuring
- # the inspection target itself self-inspects
- insp.__class__ in inspection._registrars
+ insp is not None
+ and
+ # foil mock.Mock() and other impostors by ensuring
+ # the inspection target itself self-inspects
+ insp.__class__ in inspection._registrars
):
raise exc.ArgumentError(
- "Object %r is not legal as a SQL literal value" % value)
+ "Object %r is not legal as a SQL literal value" % value
+ )
return NULLTYPE
else:
return _result_type
+
# back-assign to type_api
from . import type_api
+
type_api.BOOLEANTYPE = BOOLEANTYPE
type_api.STRINGTYPE = STRINGTYPE
type_api.INTEGERTYPE = INTEGERTYPE
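# Sketch (editorial) closing the types.py hunks: Enum accepts a PEP 435 enum
# class, and values_callable (touched above) controls which strings are
# persisted. The Color class and the lambda are illustrative.
import enum

from sqlalchemy import Enum


class Color(enum.Enum):
    red = "r"
    green = "g"


# persist member values ("r", "g") rather than member names ("red", "green")
color_type = Enum(Color, values_callable=lambda e: [m.value for m in e])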
diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py
index a8dfa19be..7fe780783 100644
--- a/lib/sqlalchemy/sql/type_api.py
+++ b/lib/sqlalchemy/sql/type_api.py
@@ -49,7 +49,8 @@ class TypeEngine(Visitable):
"""
- __slots__ = 'expr', 'type'
+
+ __slots__ = "expr", "type"
default_comparator = None
@@ -57,16 +58,15 @@ class TypeEngine(Visitable):
self.expr = expr
self.type = expr.type
- @util.dependencies('sqlalchemy.sql.default_comparator')
+ @util.dependencies("sqlalchemy.sql.default_comparator")
def operate(self, default_comparator, op, *other, **kwargs):
o = default_comparator.operator_lookup[op.__name__]
return o[0](self.expr, op, *(other + o[1:]), **kwargs)
- @util.dependencies('sqlalchemy.sql.default_comparator')
+ @util.dependencies("sqlalchemy.sql.default_comparator")
def reverse_operate(self, default_comparator, op, other, **kwargs):
o = default_comparator.operator_lookup[op.__name__]
- return o[0](self.expr, op, other,
- reverse=True, *o[1:], **kwargs)
+ return o[0](self.expr, op, other, reverse=True, *o[1:], **kwargs)
def _adapt_expression(self, op, other_comparator):
"""evaluate the return type of <self> <op> <othertype>,
@@ -97,7 +97,7 @@ class TypeEngine(Visitable):
return op, self.type
def __reduce__(self):
- return _reconstitute_comparator, (self.expr, )
+ return _reconstitute_comparator, (self.expr,)
hashable = True
"""Flag, if False, means values from this type aren't hashable.
@@ -313,8 +313,10 @@ class TypeEngine(Visitable):
"""
- return self.__class__.column_expression.__code__ \
+ return (
+ self.__class__.column_expression.__code__
is not TypeEngine.column_expression.__code__
+ )
def bind_expression(self, bindvalue):
""""Given a bind value (i.e. a :class:`.BindParameter` instance),
@@ -351,8 +353,10 @@ class TypeEngine(Visitable):
"""
- return self.__class__.bind_expression.__code__ \
+ return (
+ self.__class__.bind_expression.__code__
is not TypeEngine.bind_expression.__code__
+ )
@staticmethod
def _to_instance(cls_or_self):
@@ -441,9 +445,9 @@ class TypeEngine(Visitable):
"""
try:
- return dialect._type_memos[self]['impl']
+ return dialect._type_memos[self]["impl"]
except KeyError:
- return self._dialect_info(dialect)['impl']
+ return self._dialect_info(dialect)["impl"]
def _unwrapped_dialect_impl(self, dialect):
"""Return the 'unwrapped' dialect impl for this type.
@@ -462,20 +466,20 @@ class TypeEngine(Visitable):
def _cached_literal_processor(self, dialect):
"""Return a dialect-specific literal processor for this type."""
try:
- return dialect._type_memos[self]['literal']
+ return dialect._type_memos[self]["literal"]
except KeyError:
d = self._dialect_info(dialect)
- d['literal'] = lp = d['impl'].literal_processor(dialect)
+ d["literal"] = lp = d["impl"].literal_processor(dialect)
return lp
def _cached_bind_processor(self, dialect):
"""Return a dialect-specific bind processor for this type."""
try:
- return dialect._type_memos[self]['bind']
+ return dialect._type_memos[self]["bind"]
except KeyError:
d = self._dialect_info(dialect)
- d['bind'] = bp = d['impl'].bind_processor(dialect)
+ d["bind"] = bp = d["impl"].bind_processor(dialect)
return bp
def _cached_result_processor(self, dialect, coltype):
@@ -488,7 +492,7 @@ class TypeEngine(Visitable):
# key assumption: DBAPI type codes are
# constants. Else this dictionary would
# grow unbounded.
- d[coltype] = rp = d['impl'].result_processor(dialect, coltype)
+ d[coltype] = rp = d["impl"].result_processor(dialect, coltype)
return rp
def _cached_custom_processor(self, dialect, key, fn):
@@ -496,7 +500,7 @@ class TypeEngine(Visitable):
return dialect._type_memos[self][key]
except KeyError:
d = self._dialect_info(dialect)
- impl = d['impl']
+ impl = d["impl"]
d[key] = result = fn(impl)
return result
@@ -513,7 +517,7 @@ class TypeEngine(Visitable):
impl = self.adapt(type(self))
# this can't be self, else we create a cycle
assert impl is not self
- dialect._type_memos[self] = d = {'impl': impl}
+ dialect._type_memos[self] = d = {"impl": impl}
return d
def _gen_dialect_impl(self, dialect):
@@ -549,8 +553,10 @@ class TypeEngine(Visitable):
"""
_coerced_type = _resolve_value_to_type(value)
- if _coerced_type is NULLTYPE or _coerced_type._type_affinity \
- is self._type_affinity:
+ if (
+ _coerced_type is NULLTYPE
+ or _coerced_type._type_affinity is self._type_affinity
+ ):
return self
else:
return _coerced_type
@@ -586,8 +592,7 @@ class TypeEngine(Visitable):
def __str__(self):
if util.py2k:
- return unicode(self.compile()).\
- encode('ascii', 'backslashreplace')
+ return unicode(self.compile()).encode("ascii", "backslashreplace")
else:
return str(self.compile())
@@ -645,15 +650,16 @@ class UserDefinedType(util.with_metaclass(VisitableCheckKWArg, TypeEngine)):
``type_expression``, if it receives ``**kw`` in its signature.
"""
+
__visit_name__ = "user_defined"
- ensure_kwarg = 'get_col_spec'
+ ensure_kwarg = "get_col_spec"
class Comparator(TypeEngine.Comparator):
__slots__ = ()
def _adapt_expression(self, op, other_comparator):
- if hasattr(self.type, 'adapt_operator'):
+ if hasattr(self.type, "adapt_operator"):
util.warn_deprecated(
"UserDefinedType.adapt_operator is deprecated. Create "
"a UserDefinedType.Comparator subclass instead which "
@@ -854,6 +860,7 @@ class TypeDecorator(SchemaEventTarget, TypeEngine):
will cause the index value ``'foo'`` to be JSON encoded.
"""
+
__visit_name__ = "type_decorator"
def __init__(self, *args, **kwargs):
@@ -874,14 +881,16 @@ class TypeDecorator(SchemaEventTarget, TypeEngine):
"""
- if not hasattr(self.__class__, 'impl'):
- raise AssertionError("TypeDecorator implementations "
- "require a class-level variable "
- "'impl' which refers to the class of "
- "type being decorated")
+ if not hasattr(self.__class__, "impl"):
+ raise AssertionError(
+ "TypeDecorator implementations "
+ "require a class-level variable "
+ "'impl' which refers to the class of "
+ "type being decorated"
+ )
self.impl = to_instance(self.__class__.impl, *args, **kwargs)
- coerce_to_is_types = (util.NoneType, )
+ coerce_to_is_types = (util.NoneType,)
"""Specify those Python types which should be coerced at the expression
level to "IS <constant>" when compared using ``==`` (and same for
``IS NOT`` in conjunction with ``!=``).
@@ -906,24 +915,27 @@ class TypeDecorator(SchemaEventTarget, TypeEngine):
__slots__ = ()
def operate(self, op, *other, **kwargs):
- kwargs['_python_is_types'] = self.expr.type.coerce_to_is_types
+ kwargs["_python_is_types"] = self.expr.type.coerce_to_is_types
return super(TypeDecorator.Comparator, self).operate(
- op, *other, **kwargs)
+ op, *other, **kwargs
+ )
def reverse_operate(self, op, other, **kwargs):
- kwargs['_python_is_types'] = self.expr.type.coerce_to_is_types
+ kwargs["_python_is_types"] = self.expr.type.coerce_to_is_types
return super(TypeDecorator.Comparator, self).reverse_operate(
- op, other, **kwargs)
+ op, other, **kwargs
+ )
@property
def comparator_factory(self):
if TypeDecorator.Comparator in self.impl.comparator_factory.__mro__:
return self.impl.comparator_factory
else:
- return type("TDComparator",
- (TypeDecorator.Comparator,
- self.impl.comparator_factory),
- {})
+ return type(
+ "TDComparator",
+ (TypeDecorator.Comparator, self.impl.comparator_factory),
+ {},
+ )
def _gen_dialect_impl(self, dialect):
"""
@@ -939,10 +951,11 @@ class TypeDecorator(SchemaEventTarget, TypeEngine):
typedesc = self._unwrapped_dialect_impl(dialect)
tt = self.copy()
if not isinstance(tt, self.__class__):
- raise AssertionError('Type object %s does not properly '
- 'implement the copy() method, it must '
- 'return an object of type %s' %
- (self, self.__class__))
+ raise AssertionError(
+ "Type object %s does not properly "
+ "implement the copy() method, it must "
+ "return an object of type %s" % (self, self.__class__)
+ )
tt.impl = typedesc
return tt
@@ -1099,8 +1112,10 @@ class TypeDecorator(SchemaEventTarget, TypeEngine):
"""
- return self.__class__.process_bind_param.__code__ \
+ return (
+ self.__class__.process_bind_param.__code__
is not TypeDecorator.process_bind_param.__code__
+ )
@util.memoized_property
def _has_literal_processor(self):
@@ -1109,8 +1124,10 @@ class TypeDecorator(SchemaEventTarget, TypeEngine):
"""
- return self.__class__.process_literal_param.__code__ \
+ return (
+ self.__class__.process_literal_param.__code__
is not TypeDecorator.process_literal_param.__code__
+ )
def literal_processor(self, dialect):
"""Provide a literal processing function for the given
@@ -1147,9 +1164,12 @@ class TypeDecorator(SchemaEventTarget, TypeEngine):
if process_param:
impl_processor = self.impl.literal_processor(dialect)
if impl_processor:
+
def process(value):
return impl_processor(process_param(value, dialect))
+
else:
+
def process(value):
return process_param(value, dialect)
@@ -1180,10 +1200,12 @@ class TypeDecorator(SchemaEventTarget, TypeEngine):
process_param = self.process_bind_param
impl_processor = self.impl.bind_processor(dialect)
if impl_processor:
+
def process(value):
return impl_processor(process_param(value, dialect))
else:
+
def process(value):
return process_param(value, dialect)
@@ -1200,8 +1222,10 @@ class TypeDecorator(SchemaEventTarget, TypeEngine):
exception throw.
"""
- return self.__class__.process_result_value.__code__ \
+ return (
+ self.__class__.process_result_value.__code__
is not TypeDecorator.process_result_value.__code__
+ )
def result_processor(self, dialect, coltype):
"""Provide a result value processing function for the given
@@ -1225,13 +1249,14 @@ class TypeDecorator(SchemaEventTarget, TypeEngine):
"""
if self._has_result_processor:
process_value = self.process_result_value
- impl_processor = self.impl.result_processor(dialect,
- coltype)
+ impl_processor = self.impl.result_processor(dialect, coltype)
if impl_processor:
+
def process(value):
return process_value(impl_processor(value), dialect)
else:
+
def process(value):
return process_value(value, dialect)
@@ -1397,7 +1422,8 @@ class Variant(TypeDecorator):
if dialect_name in self.mapping:
raise exc.ArgumentError(
"Dialect '%s' is already present in "
- "the mapping for this Variant" % dialect_name)
+ "the mapping for this Variant" % dialect_name
+ )
mapping = self.mapping.copy()
mapping[dialect_name] = type_
return Variant(self.impl, mapping)
@@ -1439,6 +1465,6 @@ def adapt_type(typeobj, colspecs):
# but it turns out the originally given "generic" type
# is actually a subclass of our resulting type, then we were already
# given a more specific type than that required; so use that.
- if (issubclass(typeobj.__class__, impltype)):
+ if issubclass(typeobj.__class__, impltype):
return typeobj
return typeobj.adapt(impltype)
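# Sketch (editorial) closing the type_api.py hunks: the _has_bind_processor /
# _has_result_processor checks above detect whether a TypeDecorator subclass
# overrode the process_* hooks. A minimal subclass exercising both hooks;
# JSONEncodedString is illustrative, not a SQLAlchemy built-in.
import json

from sqlalchemy import String
from sqlalchemy.types import TypeDecorator


class JSONEncodedString(TypeDecorator):
    """Persist a Python structure as a JSON string in a VARCHAR column."""

    impl = String

    def process_bind_param(self, value, dialect):
        return None if value is None else json.dumps(value)

    def process_result_value(self, value, dialect):
        return None if value is None else json.loads(value)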
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
index 12cfe09d1..4feaf9938 100644
--- a/lib/sqlalchemy/sql/util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -15,15 +15,29 @@ from . import operators, visitors
from itertools import chain
from collections import deque
-from .elements import BindParameter, ColumnClause, ColumnElement, \
- Null, UnaryExpression, literal_column, Label, _label_reference, \
- _textual_label_reference
-from .selectable import SelectBase, ScalarSelect, Join, FromClause, FromGrouping
+from .elements import (
+ BindParameter,
+ ColumnClause,
+ ColumnElement,
+ Null,
+ UnaryExpression,
+ literal_column,
+ Label,
+ _label_reference,
+ _textual_label_reference,
+)
+from .selectable import (
+ SelectBase,
+ ScalarSelect,
+ Join,
+ FromClause,
+ FromGrouping,
+)
from .schema import Column
join_condition = util.langhelpers.public_factory(
- Join._join_condition,
- ".sql.util.join_condition")
+ Join._join_condition, ".sql.util.join_condition"
+)
# names that are still being imported from the outside
from .annotation import _shallow_annotate, _deep_annotate, _deep_deannotate
@@ -88,8 +102,9 @@ def find_left_clause_that_matches_given(clauses, join_from):
for idx in liberal_idx:
f = clauses[idx]
for s in selectables:
- if set(surface_selectables(f)).\
- intersection(surface_selectables(s)):
+ if set(surface_selectables(f)).intersection(
+ surface_selectables(s)
+ ):
conservative_idx.append(idx)
break
if conservative_idx:
@@ -184,8 +199,9 @@ def visit_binary_product(fn, expr):
# we don't want to dig into correlated subqueries,
# those are just column elements by themselves
yield element
- elif element.__visit_name__ == 'binary' and \
- operators.is_comparison(element.operator):
+ elif element.__visit_name__ == "binary" and operators.is_comparison(
+ element.operator
+ ):
stack.insert(0, element)
for l in visit(element.left):
for r in visit(element.right):
@@ -199,38 +215,47 @@ def visit_binary_product(fn, expr):
for elem in element.get_children():
for e in visit(elem):
yield e
+
list(visit(expr))
-def find_tables(clause, check_columns=False,
- include_aliases=False, include_joins=False,
- include_selects=False, include_crud=False):
+def find_tables(
+ clause,
+ check_columns=False,
+ include_aliases=False,
+ include_joins=False,
+ include_selects=False,
+ include_crud=False,
+):
"""locate Table objects within the given expression."""
tables = []
_visitors = {}
if include_selects:
- _visitors['select'] = _visitors['compound_select'] = tables.append
+ _visitors["select"] = _visitors["compound_select"] = tables.append
if include_joins:
- _visitors['join'] = tables.append
+ _visitors["join"] = tables.append
if include_aliases:
- _visitors['alias'] = tables.append
+ _visitors["alias"] = tables.append
if include_crud:
- _visitors['insert'] = _visitors['update'] = \
- _visitors['delete'] = lambda ent: tables.append(ent.table)
+ _visitors["insert"] = _visitors["update"] = _visitors[
+ "delete"
+ ] = lambda ent: tables.append(ent.table)
if check_columns:
+
def visit_column(column):
tables.append(column.table)
- _visitors['column'] = visit_column
- _visitors['table'] = tables.append
+ _visitors["column"] = visit_column
- visitors.traverse(clause, {'column_collections': False}, _visitors)
+ _visitors["table"] = tables.append
+
+ visitors.traverse(clause, {"column_collections": False}, _visitors)
return tables
@@ -243,10 +268,9 @@ def unwrap_order_by(clause):
stack = deque([clause])
while stack:
t = stack.popleft()
- if isinstance(t, ColumnElement) and \
- (
- not isinstance(t, UnaryExpression) or
- not operators.is_ordering_modifier(t.modifier)
+ if isinstance(t, ColumnElement) and (
+ not isinstance(t, UnaryExpression)
+ or not operators.is_ordering_modifier(t.modifier)
):
if isinstance(t, _label_reference):
t = t.element
@@ -266,9 +290,7 @@ def unwrap_label_reference(element):
if isinstance(elem, (_label_reference, _textual_label_reference)):
return elem.element
- return visitors.replacement_traverse(
- element, {}, replace
- )
+ return visitors.replacement_traverse(element, {}, replace)
def expand_column_list_from_order_by(collist, order_by):
@@ -278,17 +300,16 @@ def expand_column_list_from_order_by(collist, order_by):
in the collist.
"""
- cols_already_present = set([
- col.element if col._order_by_label_element is not None
- else col for col in collist
- ])
+ cols_already_present = set(
+ [
+ col.element if col._order_by_label_element is not None else col
+ for col in collist
+ ]
+ )
return [
- col for col in
- chain(*[
- unwrap_order_by(o)
- for o in order_by
- ])
+ col
+ for col in chain(*[unwrap_order_by(o) for o in order_by])
if col not in cols_already_present
]
@@ -325,9 +346,9 @@ def surface_column_elements(clause, include_scalar_selects=True):
be addressable in the WHERE clause of a SELECT if this element were
in the columns clause."""
- filter_ = (FromGrouping, )
+ filter_ = (FromGrouping,)
if not include_scalar_selects:
- filter_ += (SelectBase, )
+ filter_ += (SelectBase,)
stack = deque([clause])
while stack:
@@ -343,9 +364,7 @@ def selectables_overlap(left, right):
"""Return True if left/right have some overlapping selectable"""
return bool(
- set(surface_selectables(left)).intersection(
- surface_selectables(right)
- )
+ set(surface_selectables(left)).intersection(surface_selectables(right))
)
@@ -366,7 +385,7 @@ def bind_values(clause):
def visit_bindparam(bind):
v.append(bind.effective_value)
- visitors.traverse(clause, {}, {'bindparam': visit_bindparam})
+ visitors.traverse(clause, {}, {"bindparam": visit_bindparam})
return v
@@ -383,7 +402,7 @@ class _repr_base(object):
_TUPLE = 1
_DICT = 2
- __slots__ = 'max_chars',
+ __slots__ = ("max_chars",)
def trunc(self, value):
rep = repr(value)
@@ -391,10 +410,12 @@ class _repr_base(object):
if lenrep > self.max_chars:
segment_length = self.max_chars // 2
rep = (
- rep[0:segment_length] +
- (" ... (%d characters truncated) ... "
- % (lenrep - self.max_chars)) +
- rep[-segment_length:]
+ rep[0:segment_length]
+ + (
+ " ... (%d characters truncated) ... "
+ % (lenrep - self.max_chars)
+ )
+ + rep[-segment_length:]
)
return rep
@@ -402,7 +423,7 @@ class _repr_base(object):
class _repr_row(_repr_base):
"""Provide a string view of a row."""
- __slots__ = 'row',
+ __slots__ = ("row",)
def __init__(self, row, max_chars=300):
self.row = row
@@ -412,7 +433,7 @@ class _repr_row(_repr_base):
trunc = self.trunc
return "(%s%s)" % (
", ".join(trunc(value) for value in self.row),
- "," if len(self.row) == 1 else ""
+ "," if len(self.row) == 1 else "",
)
@@ -424,7 +445,7 @@ class _repr_params(_repr_base):
"""
- __slots__ = 'params', 'batches',
+ __slots__ = "params", "batches"
def __init__(self, params, batches, max_chars=300):
self.params = params
@@ -435,11 +456,13 @@ class _repr_params(_repr_base):
if isinstance(self.params, list):
typ = self._LIST
ismulti = self.params and isinstance(
- self.params[0], (list, dict, tuple))
+ self.params[0], (list, dict, tuple)
+ )
elif isinstance(self.params, tuple):
typ = self._TUPLE
ismulti = self.params and isinstance(
- self.params[0], (list, dict, tuple))
+ self.params[0], (list, dict, tuple)
+ )
elif isinstance(self.params, dict):
typ = self._DICT
ismulti = False
@@ -448,11 +471,15 @@ class _repr_params(_repr_base):
if ismulti and len(self.params) > self.batches:
msg = " ... displaying %i of %i total bound parameter sets ... "
- return ' '.join((
- self._repr_multi(self.params[:self.batches - 2], typ)[0:-1],
- msg % (self.batches, len(self.params)),
- self._repr_multi(self.params[-2:], typ)[1:]
- ))
+ return " ".join(
+ (
+ self._repr_multi(self.params[: self.batches - 2], typ)[
+ 0:-1
+ ],
+ msg % (self.batches, len(self.params)),
+ self._repr_multi(self.params[-2:], typ)[1:],
+ )
+ )
elif ismulti:
return self._repr_multi(self.params, typ)
else:
@@ -467,12 +494,13 @@ class _repr_params(_repr_base):
elif isinstance(multi_params[0], dict):
elem_type = self._DICT
else:
- assert False, \
- "Unknown parameter type %s" % (type(multi_params[0]))
+ assert False, "Unknown parameter type %s" % (
+ type(multi_params[0])
+ )
elements = ", ".join(
- self._repr_params(params, elem_type)
- for params in multi_params)
+ self._repr_params(params, elem_type) for params in multi_params
+ )
else:
elements = ""
@@ -493,13 +521,10 @@ class _repr_params(_repr_base):
elif typ is self._TUPLE:
return "(%s%s)" % (
", ".join(trunc(value) for value in params),
- "," if len(params) == 1 else ""
-
+ "," if len(params) == 1 else "",
)
else:
- return "[%s]" % (
- ", ".join(trunc(value) for value in params)
- )
+ return "[%s]" % (", ".join(trunc(value) for value in params))
def adapt_criterion_to_null(crit, nulls):
@@ -509,20 +534,24 @@ def adapt_criterion_to_null(crit, nulls):
"""
def visit_binary(binary):
- if isinstance(binary.left, BindParameter) \
- and binary.left._identifying_key in nulls:
+ if (
+ isinstance(binary.left, BindParameter)
+ and binary.left._identifying_key in nulls
+ ):
# reverse order if the NULL is on the left side
binary.left = binary.right
binary.right = Null()
binary.operator = operators.is_
binary.negate = operators.isnot
- elif isinstance(binary.right, BindParameter) \
- and binary.right._identifying_key in nulls:
+ elif (
+ isinstance(binary.right, BindParameter)
+ and binary.right._identifying_key in nulls
+ ):
binary.right = Null()
binary.operator = operators.is_
binary.negate = operators.isnot
- return visitors.cloned_traverse(crit, {}, {'binary': visit_binary})
+ return visitors.cloned_traverse(crit, {}, {"binary": visit_binary})
def splice_joins(left, right, stop_on=None):
@@ -570,8 +599,8 @@ def reduce_columns(columns, *clauses, **kw):
in the selectable to just those that are not repeated.
"""
- ignore_nonexistent_tables = kw.pop('ignore_nonexistent_tables', False)
- only_synonyms = kw.pop('only_synonyms', False)
+ ignore_nonexistent_tables = kw.pop("ignore_nonexistent_tables", False)
+ only_synonyms = kw.pop("only_synonyms", False)
columns = util.ordered_column_set(columns)
@@ -597,39 +626,48 @@ def reduce_columns(columns, *clauses, **kw):
continue
else:
raise
- if fk_col.shares_lineage(c) and \
- (not only_synonyms or
- c.name == col.name):
+ if fk_col.shares_lineage(c) and (
+ not only_synonyms or c.name == col.name
+ ):
omit.add(col)
break
if clauses:
+
def visit_binary(binary):
if binary.operator == operators.eq:
cols = util.column_set(
- chain(*[c.proxy_set for c in columns.difference(omit)]))
+ chain(*[c.proxy_set for c in columns.difference(omit)])
+ )
if binary.left in cols and binary.right in cols:
for c in reversed(columns):
- if c.shares_lineage(binary.right) and \
- (not only_synonyms or
- c.name == binary.left.name):
+ if c.shares_lineage(binary.right) and (
+ not only_synonyms or c.name == binary.left.name
+ ):
omit.add(c)
break
+
for clause in clauses:
if clause is not None:
- visitors.traverse(clause, {}, {'binary': visit_binary})
+ visitors.traverse(clause, {}, {"binary": visit_binary})
return ColumnSet(columns.difference(omit))
-def criterion_as_pairs(expression, consider_as_foreign_keys=None,
- consider_as_referenced_keys=None, any_operator=False):
+def criterion_as_pairs(
+ expression,
+ consider_as_foreign_keys=None,
+ consider_as_referenced_keys=None,
+ any_operator=False,
+):
"""traverse an expression and locate binary criterion pairs."""
if consider_as_foreign_keys and consider_as_referenced_keys:
- raise exc.ArgumentError("Can only specify one of "
- "'consider_as_foreign_keys' or "
- "'consider_as_referenced_keys'")
+ raise exc.ArgumentError(
+ "Can only specify one of "
+ "'consider_as_foreign_keys' or "
+ "'consider_as_referenced_keys'"
+ )
def col_is(a, b):
# return a is b
@@ -638,37 +676,44 @@ def criterion_as_pairs(expression, consider_as_foreign_keys=None,
def visit_binary(binary):
if not any_operator and binary.operator is not operators.eq:
return
- if not isinstance(binary.left, ColumnElement) or \
- not isinstance(binary.right, ColumnElement):
+ if not isinstance(binary.left, ColumnElement) or not isinstance(
+ binary.right, ColumnElement
+ ):
return
if consider_as_foreign_keys:
- if binary.left in consider_as_foreign_keys and \
- (col_is(binary.right, binary.left) or
- binary.right not in consider_as_foreign_keys):
+ if binary.left in consider_as_foreign_keys and (
+ col_is(binary.right, binary.left)
+ or binary.right not in consider_as_foreign_keys
+ ):
pairs.append((binary.right, binary.left))
- elif binary.right in consider_as_foreign_keys and \
- (col_is(binary.left, binary.right) or
- binary.left not in consider_as_foreign_keys):
+ elif binary.right in consider_as_foreign_keys and (
+ col_is(binary.left, binary.right)
+ or binary.left not in consider_as_foreign_keys
+ ):
pairs.append((binary.left, binary.right))
elif consider_as_referenced_keys:
- if binary.left in consider_as_referenced_keys and \
- (col_is(binary.right, binary.left) or
- binary.right not in consider_as_referenced_keys):
+ if binary.left in consider_as_referenced_keys and (
+ col_is(binary.right, binary.left)
+ or binary.right not in consider_as_referenced_keys
+ ):
pairs.append((binary.left, binary.right))
- elif binary.right in consider_as_referenced_keys and \
- (col_is(binary.left, binary.right) or
- binary.left not in consider_as_referenced_keys):
+ elif binary.right in consider_as_referenced_keys and (
+ col_is(binary.left, binary.right)
+ or binary.left not in consider_as_referenced_keys
+ ):
pairs.append((binary.right, binary.left))
else:
- if isinstance(binary.left, Column) and \
- isinstance(binary.right, Column):
+ if isinstance(binary.left, Column) and isinstance(
+ binary.right, Column
+ ):
if binary.left.references(binary.right):
pairs.append((binary.right, binary.left))
elif binary.right.references(binary.left):
pairs.append((binary.left, binary.right))
+
pairs = []
- visitors.traverse(expression, {}, {'binary': visit_binary})
+ visitors.traverse(expression, {}, {"binary": visit_binary})
return pairs
@@ -699,28 +744,38 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor):
"""
- def __init__(self, selectable, equivalents=None,
- include_fn=None, exclude_fn=None,
- adapt_on_names=False, anonymize_labels=False):
+ def __init__(
+ self,
+ selectable,
+ equivalents=None,
+ include_fn=None,
+ exclude_fn=None,
+ adapt_on_names=False,
+ anonymize_labels=False,
+ ):
self.__traverse_options__ = {
- 'stop_on': [selectable],
- 'anonymize_labels': anonymize_labels}
+ "stop_on": [selectable],
+ "anonymize_labels": anonymize_labels,
+ }
self.selectable = selectable
self.include_fn = include_fn
self.exclude_fn = exclude_fn
self.equivalents = util.column_dict(equivalents or {})
self.adapt_on_names = adapt_on_names
- def _corresponding_column(self, col, require_embedded,
- _seen=util.EMPTY_SET):
+ def _corresponding_column(
+ self, col, require_embedded, _seen=util.EMPTY_SET
+ ):
newcol = self.selectable.corresponding_column(
- col,
- require_embedded=require_embedded)
+ col, require_embedded=require_embedded
+ )
if newcol is None and col in self.equivalents and col not in _seen:
for equiv in self.equivalents[col]:
newcol = self._corresponding_column(
- equiv, require_embedded=require_embedded,
- _seen=_seen.union([col]))
+ equiv,
+ require_embedded=require_embedded,
+ _seen=_seen.union([col]),
+ )
if newcol is not None:
return newcol
if self.adapt_on_names and newcol is None:
@@ -728,8 +783,9 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor):
return newcol
def replace(self, col):
- if isinstance(col, FromClause) and \
- self.selectable.is_derived_from(col):
+ if isinstance(col, FromClause) and self.selectable.is_derived_from(
+ col
+ ):
return self.selectable
elif not isinstance(col, ColumnElement):
return None
@@ -772,16 +828,27 @@ class ColumnAdapter(ClauseAdapter):
"""
- def __init__(self, selectable, equivalents=None,
- chain_to=None, adapt_required=False,
- include_fn=None, exclude_fn=None,
- adapt_on_names=False,
- allow_label_resolve=True,
- anonymize_labels=False):
- ClauseAdapter.__init__(self, selectable, equivalents,
- include_fn=include_fn, exclude_fn=exclude_fn,
- adapt_on_names=adapt_on_names,
- anonymize_labels=anonymize_labels)
+ def __init__(
+ self,
+ selectable,
+ equivalents=None,
+ chain_to=None,
+ adapt_required=False,
+ include_fn=None,
+ exclude_fn=None,
+ adapt_on_names=False,
+ allow_label_resolve=True,
+ anonymize_labels=False,
+ ):
+ ClauseAdapter.__init__(
+ self,
+ selectable,
+ equivalents,
+ include_fn=include_fn,
+ exclude_fn=exclude_fn,
+ adapt_on_names=adapt_on_names,
+ anonymize_labels=anonymize_labels,
+ )
if chain_to:
self.chain(chain_to)
@@ -800,9 +867,7 @@ class ColumnAdapter(ClauseAdapter):
def __getitem__(self, key):
if (
self.parent.include_fn and not self.parent.include_fn(key)
- ) or (
- self.parent.exclude_fn and self.parent.exclude_fn(key)
- ):
+ ) or (self.parent.exclude_fn and self.parent.exclude_fn(key)):
if self.parent._wrap:
return self.parent._wrap.columns[key]
else:
@@ -843,7 +908,7 @@ class ColumnAdapter(ClauseAdapter):
def __getstate__(self):
d = self.__dict__.copy()
- del d['columns']
+ del d["columns"]
return d
def __setstate__(self, state):
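# Sketch (editorial) closing the sql/util.py hunks: find_tables(), reformatted
# above, traverses an expression and collects the Table objects it references.
from sqlalchemy import Column, Integer, MetaData, Table, select
from sqlalchemy.sql.util import find_tables

t = Table("t", MetaData(), Column("id", Integer))
stmt = select([t.c.id]).where(t.c.id > 5)  # 1.x-style select()
assert t in find_tables(stmt)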
diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py
index b39ec8167..bf1743643 100644
--- a/lib/sqlalchemy/sql/visitors.py
+++ b/lib/sqlalchemy/sql/visitors.py
@@ -29,11 +29,20 @@ from .. import util
import operator
from .. import exc
-__all__ = ['VisitableType', 'Visitable', 'ClauseVisitor',
- 'CloningVisitor', 'ReplacingCloningVisitor', 'iterate',
- 'iterate_depthfirst', 'traverse_using', 'traverse',
- 'traverse_depthfirst',
- 'cloned_traverse', 'replacement_traverse']
+__all__ = [
+ "VisitableType",
+ "Visitable",
+ "ClauseVisitor",
+ "CloningVisitor",
+ "ReplacingCloningVisitor",
+ "iterate",
+ "iterate_depthfirst",
+ "traverse_using",
+ "traverse",
+ "traverse_depthfirst",
+ "cloned_traverse",
+ "replacement_traverse",
+]
class VisitableType(type):
@@ -53,8 +62,7 @@ class VisitableType(type):
"""
def __init__(cls, clsname, bases, clsdict):
- if clsname != 'Visitable' and \
- hasattr(cls, '__visit_name__'):
+ if clsname != "Visitable" and hasattr(cls, "__visit_name__"):
_generate_dispatch(cls)
super(VisitableType, cls).__init__(clsname, bases, clsdict)
@@ -64,7 +72,7 @@ def _generate_dispatch(cls):
"""Return an optimized visit dispatch function for the cls
for use by the compiler.
"""
- if '__visit_name__' in cls.__dict__:
+ if "__visit_name__" in cls.__dict__:
visit_name = cls.__visit_name__
if isinstance(visit_name, str):
# There is an optimization opportunity here because the
@@ -79,12 +87,13 @@ def _generate_dispatch(cls):
raise exc.UnsupportedCompilationError(visitor, cls)
else:
return meth(self, **kw)
+
else:
# The optimization opportunity is lost for this case because the
# __visit_name__ is not yet a string. As a result, the visit
# string has to be recalculated with each compilation.
def _compiler_dispatch(self, visitor, **kw):
- visit_attr = 'visit_%s' % self.__visit_name__
+ visit_attr = "visit_%s" % self.__visit_name__
try:
meth = getattr(visitor, visit_attr)
except AttributeError:
@@ -92,8 +101,7 @@ def _generate_dispatch(cls):
else:
return meth(self, **kw)
- _compiler_dispatch.__doc__ = \
- """Look for an attribute named "visit_" + self.__visit_name__
+ _compiler_dispatch.__doc__ = """Look for an attribute named "visit_" + self.__visit_name__
on the visitor, and call it with the same kw params.
"""
cls._compiler_dispatch = _compiler_dispatch
@@ -137,7 +145,7 @@ class ClauseVisitor(object):
visitors = {}
for name in dir(self):
- if name.startswith('visit_'):
+ if name.startswith("visit_"):
visitors[name[6:]] = getattr(self, name)
return visitors
@@ -148,7 +156,7 @@ class ClauseVisitor(object):
v = self
while v:
yield v
- v = getattr(v, '_next', None)
+ v = getattr(v, "_next", None)
def chain(self, visitor):
"""'chain' an additional ClauseVisitor onto this ClauseVisitor.
@@ -178,7 +186,8 @@ class CloningVisitor(ClauseVisitor):
"""traverse and visit the given expression structure."""
return cloned_traverse(
- obj, self.__traverse_options__, self._visitor_dict)
+ obj, self.__traverse_options__, self._visitor_dict
+ )
class ReplacingCloningVisitor(CloningVisitor):
@@ -204,6 +213,7 @@ class ReplacingCloningVisitor(CloningVisitor):
e = v.replace(elem)
if e is not None:
return e
+
return replacement_traverse(obj, self.__traverse_options__, replace)
@@ -282,7 +292,7 @@ def cloned_traverse(obj, opts, visitors):
modifications by visitors."""
cloned = {}
- stop_on = set(opts.get('stop_on', []))
+ stop_on = set(opts.get("stop_on", []))
def clone(elem):
if elem in stop_on:
@@ -306,11 +316,13 @@ def replacement_traverse(obj, opts, replace):
replacement by a given replacement function."""
cloned = {}
- stop_on = {id(x) for x in opts.get('stop_on', [])}
+ stop_on = {id(x) for x in opts.get("stop_on", [])}
def clone(elem, **kw):
- if id(elem) in stop_on or \
- 'no_replacement_traverse' in elem._annotations:
+ if (
+ id(elem) in stop_on
+ or "no_replacement_traverse" in elem._annotations
+ ):
return elem
else:
newelem = replace(elem)
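# Sketch (editorial) closing the visitors.py hunks: visitors.traverse()
# dispatches each element to a callable keyed by its __visit_name__ -- the
# same mechanism bind_values() and find_tables() rely on earlier in this diff.
from sqlalchemy import Column, Integer, MetaData, Table
from sqlalchemy.sql import visitors

t = Table("t", MetaData(), Column("id", Integer))
expr = t.c.id == 7
seen = []
visitors.traverse(expr, {}, {"bindparam": lambda b: seen.append(b.value)})
assert seen == [7]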
diff --git a/lib/sqlalchemy/testing/__init__.py b/lib/sqlalchemy/testing/__init__.py
index 413a492b8..f46ca4528 100644
--- a/lib/sqlalchemy/testing/__init__.py
+++ b/lib/sqlalchemy/testing/__init__.py
@@ -10,23 +10,62 @@ from .warnings import assert_warnings
from . import config
-from .exclusions import db_spec, _is_excluded, fails_if, skip_if, future,\
- fails_on, fails_on_everything_except, skip, only_on, exclude, \
- against as _against, _server_version, only_if, fails
+from .exclusions import (
+ db_spec,
+ _is_excluded,
+ fails_if,
+ skip_if,
+ future,
+ fails_on,
+ fails_on_everything_except,
+ skip,
+ only_on,
+ exclude,
+ against as _against,
+ _server_version,
+ only_if,
+ fails,
+)
def against(*queries):
return _against(config._current, *queries)
-from .assertions import emits_warning, emits_warning_on, uses_deprecated, \
- eq_, ne_, le_, is_, is_not_, startswith_, assert_raises, \
- assert_raises_message, AssertsCompiledSQL, ComparesTables, \
- AssertsExecutionResults, expect_deprecated, expect_warnings, \
- in_, not_in_, eq_ignore_whitespace, eq_regex, is_true, is_false
-from .util import run_as_contextmanager, rowset, fail, \
- provide_metadata, adict, force_drop_names, \
- teardown_events
+from .assertions import (
+ emits_warning,
+ emits_warning_on,
+ uses_deprecated,
+ eq_,
+ ne_,
+ le_,
+ is_,
+ is_not_,
+ startswith_,
+ assert_raises,
+ assert_raises_message,
+ AssertsCompiledSQL,
+ ComparesTables,
+ AssertsExecutionResults,
+ expect_deprecated,
+ expect_warnings,
+ in_,
+ not_in_,
+ eq_ignore_whitespace,
+ eq_regex,
+ is_true,
+ is_false,
+)
+
+from .util import (
+ run_as_contextmanager,
+ rowset,
+ fail,
+ provide_metadata,
+ adict,
+ force_drop_names,
+ teardown_events,
+)
crashes = skip
diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py
index e42376921..73ab4556a 100644
--- a/lib/sqlalchemy/testing/assertions.py
+++ b/lib/sqlalchemy/testing/assertions.py
@@ -86,6 +86,7 @@ def emits_warning_on(db, *messages):
were in fact seen.
"""
+
@decorator
def decorate(fn, *args, **kw):
with expect_warnings_on(db, assert_=False, *messages):
@@ -114,12 +115,14 @@ def uses_deprecated(*messages):
def decorate(fn, *args, **kw):
with expect_deprecated(*messages, assert_=False):
return fn(*args, **kw)
+
return decorate
@contextlib.contextmanager
-def _expect_warnings(exc_cls, messages, regex=True, assert_=True,
- py2konly=False):
+def _expect_warnings(
+ exc_cls, messages, regex=True, assert_=True, py2konly=False
+):
if regex:
filters = [re.compile(msg, re.I | re.S) for msg in messages]
@@ -145,8 +148,9 @@ def _expect_warnings(exc_cls, messages, regex=True, assert_=True,
return
for filter_ in filters:
- if (regex and filter_.match(msg)) or \
- (not regex and filter_ == msg):
+ if (regex and filter_.match(msg)) or (
+ not regex and filter_ == msg
+ ):
seen.discard(filter_)
break
else:
@@ -156,8 +160,9 @@ def _expect_warnings(exc_cls, messages, regex=True, assert_=True,
yield
if assert_ and (not py2konly or not compat.py3k):
- assert not seen, "Warnings were not seen: %s" % \
- ", ".join("%r" % (s.pattern if regex else s) for s in seen)
+ assert not seen, "Warnings were not seen: %s" % ", ".join(
+ "%r" % (s.pattern if regex else s) for s in seen
+ )
def global_cleanup_assertions():
@@ -170,6 +175,7 @@ def global_cleanup_assertions():
"""
_assert_no_stray_pool_connections()
+
_STRAY_CONNECTION_FAILURES = 0
@@ -187,8 +193,10 @@ def _assert_no_stray_pool_connections():
# OK, let's be somewhat forgiving.
_STRAY_CONNECTION_FAILURES += 1
- print("Encountered a stray connection in test cleanup: %s"
- % str(pool._refs))
+ print(
+ "Encountered a stray connection in test cleanup: %s"
+ % str(pool._refs)
+ )
# then do a real GC sweep. We shouldn't even be here
# so a single sweep should really be doing it, otherwise
# there's probably a real unreachable cycle somewhere.
@@ -206,8 +214,8 @@ def _assert_no_stray_pool_connections():
pool._refs.clear()
_STRAY_CONNECTION_FAILURES = 0
warnings.warn(
- "Stray connection refused to leave "
- "after gc.collect(): %s" % err)
+ "Stray connection refused to leave " "after gc.collect(): %s" % err
+ )
elif _STRAY_CONNECTION_FAILURES > 10:
assert False, "Encountered more than 10 stray connections"
_STRAY_CONNECTION_FAILURES = 0
@@ -263,14 +271,16 @@ def not_in_(a, b, msg=None):
def startswith_(a, fragment, msg=None):
"""Assert a.startswith(fragment), with repr messaging on failure."""
assert a.startswith(fragment), msg or "%r does not start with %r" % (
- a, fragment)
+ a,
+ fragment,
+ )
def eq_ignore_whitespace(a, b, msg=None):
- a = re.sub(r'^\s+?|\n', "", a)
- a = re.sub(r' {2,}', " ", a)
- b = re.sub(r'^\s+?|\n', "", b)
- b = re.sub(r' {2,}', " ", b)
+ a = re.sub(r"^\s+?|\n", "", a)
+ a = re.sub(r" {2,}", " ", a)
+ b = re.sub(r"^\s+?|\n", "", b)
+ b = re.sub(r" {2,}", " ", b)
assert a == b, msg or "%r != %r" % (a, b)
@@ -291,32 +301,41 @@ def assert_raises_message(except_cls, msg, callable_, *args, **kwargs):
callable_(*args, **kwargs)
assert False, "Callable did not raise an exception"
except except_cls as e:
- assert re.search(
- msg, util.text_type(e), re.UNICODE), "%r !~ %s" % (msg, e)
- print(util.text_type(e).encode('utf-8'))
+ assert re.search(msg, util.text_type(e), re.UNICODE), "%r !~ %s" % (
+ msg,
+ e,
+ )
+ print(util.text_type(e).encode("utf-8"))
+
class AssertsCompiledSQL(object):
- def assert_compile(self, clause, result, params=None,
- checkparams=None, dialect=None,
- checkpositional=None,
- check_prefetch=None,
- use_default_dialect=False,
- allow_dialect_select=False,
- literal_binds=False,
- schema_translate_map=None):
+ def assert_compile(
+ self,
+ clause,
+ result,
+ params=None,
+ checkparams=None,
+ dialect=None,
+ checkpositional=None,
+ check_prefetch=None,
+ use_default_dialect=False,
+ allow_dialect_select=False,
+ literal_binds=False,
+ schema_translate_map=None,
+ ):
if use_default_dialect:
dialect = default.DefaultDialect()
elif allow_dialect_select:
dialect = None
else:
if dialect is None:
- dialect = getattr(self, '__dialect__', None)
+ dialect = getattr(self, "__dialect__", None)
if dialect is None:
dialect = config.db.dialect
- elif dialect == 'default':
+ elif dialect == "default":
dialect = default.DefaultDialect()
- elif dialect == 'default_enhanced':
+ elif dialect == "default_enhanced":
dialect = default.StrCompileDialect()
elif isinstance(dialect, util.string_types):
dialect = url.URL(dialect).get_dialect()()
@@ -325,13 +344,13 @@ class AssertsCompiledSQL(object):
compile_kwargs = {}
if schema_translate_map:
- kw['schema_translate_map'] = schema_translate_map
+ kw["schema_translate_map"] = schema_translate_map
if params is not None:
- kw['column_keys'] = list(params)
+ kw["column_keys"] = list(params)
if literal_binds:
- compile_kwargs['literal_binds'] = True
+ compile_kwargs["literal_binds"] = True
if isinstance(clause, orm.Query):
context = clause._compile_context()
@@ -343,25 +362,27 @@ class AssertsCompiledSQL(object):
clause = stmt_mock.mock_calls[0][1][0]
if compile_kwargs:
- kw['compile_kwargs'] = compile_kwargs
+ kw["compile_kwargs"] = compile_kwargs
c = clause.compile(dialect=dialect, **kw)
- param_str = repr(getattr(c, 'params', {}))
+ param_str = repr(getattr(c, "params", {}))
if util.py3k:
- param_str = param_str.encode('utf-8').decode('ascii', 'ignore')
+ param_str = param_str.encode("utf-8").decode("ascii", "ignore")
print(
- ("\nSQL String:\n" +
- util.text_type(c) +
- param_str).encode('utf-8'))
+ ("\nSQL String:\n" + util.text_type(c) + param_str).encode(
+ "utf-8"
+ )
+ )
else:
print(
- "\nSQL String:\n" +
- util.text_type(c).encode('utf-8') +
- param_str)
+ "\nSQL String:\n"
+ + util.text_type(c).encode("utf-8")
+ + param_str
+ )
- cc = re.sub(r'[\n\t]', '', util.text_type(c))
+ cc = re.sub(r"[\n\t]", "", util.text_type(c))
eq_(cc, result, "%r != %r on dialect %r" % (cc, result, dialect))
@@ -375,7 +396,6 @@ class AssertsCompiledSQL(object):
class ComparesTables(object):
-
def assert_tables_equal(self, table, reflected_table, strict_types=False):
assert len(table.c) == len(reflected_table.c)
for c, reflected_c in zip(table.c, reflected_table.c):
@@ -386,8 +406,10 @@ class ComparesTables(object):
if strict_types:
msg = "Type '%s' doesn't correspond to type '%s'"
- assert isinstance(reflected_c.type, type(c.type)), \
- msg % (reflected_c.type, c.type)
+ assert isinstance(reflected_c.type, type(c.type)), msg % (
+ reflected_c.type,
+ c.type,
+ )
else:
self.assert_types_base(reflected_c, c)
@@ -396,20 +418,22 @@ class ComparesTables(object):
eq_(
{f.column.name for f in c.foreign_keys},
- {f.column.name for f in reflected_c.foreign_keys}
+ {f.column.name for f in reflected_c.foreign_keys},
)
if c.server_default:
- assert isinstance(reflected_c.server_default,
- schema.FetchedValue)
+ assert isinstance(
+ reflected_c.server_default, schema.FetchedValue
+ )
assert len(table.primary_key) == len(reflected_table.primary_key)
for c in table.primary_key:
assert reflected_table.primary_key.columns[c.name] is not None
def assert_types_base(self, c1, c2):
- assert c1.type._compare_type_affinity(c2.type),\
- "On column %r, type '%s' doesn't correspond to type '%s'" % \
- (c1.name, c1.type, c2.type)
+ assert c1.type._compare_type_affinity(c2.type), (
+ "On column %r, type '%s' doesn't correspond to type '%s'"
+ % (c1.name, c1.type, c2.type)
+ )
class AssertsExecutionResults(object):
@@ -419,15 +443,19 @@ class AssertsExecutionResults(object):
self.assert_list(result, class_, objects)
def assert_list(self, result, class_, list):
- self.assert_(len(result) == len(list),
- "result list is not the same size as test list, " +
- "for class " + class_.__name__)
+ self.assert_(
+ len(result) == len(list),
+ "result list is not the same size as test list, "
+ + "for class "
+ + class_.__name__,
+ )
for i in range(0, len(list)):
self.assert_row(class_, result[i], list[i])
def assert_row(self, class_, rowobj, desc):
- self.assert_(rowobj.__class__ is class_,
- "item class is not " + repr(class_))
+ self.assert_(
+ rowobj.__class__ is class_, "item class is not " + repr(class_)
+ )
for key, value in desc.items():
if isinstance(value, tuple):
if isinstance(value[1], list):
@@ -435,9 +463,11 @@ class AssertsExecutionResults(object):
else:
self.assert_row(value[0], getattr(rowobj, key), value[1])
else:
- self.assert_(getattr(rowobj, key) == value,
- "attribute %s value %s does not match %s" % (
- key, getattr(rowobj, key), value))
+ self.assert_(
+ getattr(rowobj, key) == value,
+ "attribute %s value %s does not match %s"
+ % (key, getattr(rowobj, key), value),
+ )
def assert_unordered_result(self, result, cls, *expected):
"""As assert_result, but the order of objects is not considered.
@@ -453,14 +483,19 @@ class AssertsExecutionResults(object):
found = util.IdentitySet(result)
expected = {immutabledict(e) for e in expected}
- for wrong in util.itertools_filterfalse(lambda o:
- isinstance(o, cls), found):
- fail('Unexpected type "%s", expected "%s"' % (
- type(wrong).__name__, cls.__name__))
+ for wrong in util.itertools_filterfalse(
+ lambda o: isinstance(o, cls), found
+ ):
+ fail(
+ 'Unexpected type "%s", expected "%s"'
+ % (type(wrong).__name__, cls.__name__)
+ )
if len(found) != len(expected):
- fail('Unexpected object count "%s", expected "%s"' % (
- len(found), len(expected)))
+ fail(
+ 'Unexpected object count "%s", expected "%s"'
+ % (len(found), len(expected))
+ )
NOVALUE = object()
@@ -469,7 +504,8 @@ class AssertsExecutionResults(object):
if isinstance(value, tuple):
try:
self.assert_unordered_result(
- getattr(obj, key), value[0], *value[1])
+ getattr(obj, key), value[0], *value[1]
+ )
except AssertionError:
return False
else:
@@ -484,8 +520,9 @@ class AssertsExecutionResults(object):
break
else:
fail(
- "Expected %s instance with attributes %s not found." % (
- cls.__name__, repr(expected_item)))
+ "Expected %s instance with attributes %s not found."
+ % (cls.__name__, repr(expected_item))
+ )
return True
def sql_execution_asserter(self, db=None):
@@ -505,9 +542,9 @@ class AssertsExecutionResults(object):
newrules = []
for rule in rules:
if isinstance(rule, dict):
- newrule = assertsql.AllOf(*[
- assertsql.CompiledSQL(k, v) for k, v in rule.items()
- ])
+ newrule = assertsql.AllOf(
+ *[assertsql.CompiledSQL(k, v) for k, v in rule.items()]
+ )
else:
newrule = assertsql.CompiledSQL(*rule)
newrules.append(newrule)
@@ -516,7 +553,8 @@ class AssertsExecutionResults(object):
def assert_sql_count(self, db, callable_, count):
self.assert_sql_execution(
- db, callable_, assertsql.CountStatements(count))
+ db, callable_, assertsql.CountStatements(count)
+ )
def assert_multiple_sql_count(self, dbs, callable_, counts):
recs = [
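
A compact sketch (assumed, not part of the patch) of the AssertsCompiledSQL mixin reformatted above: assert_compile strips newlines and tabs from the compiled statement before comparing, so expected SQL can be written on one line, and __dialect__ = "default" resolves to DefaultDialect:

    from sqlalchemy import Column, Integer, MetaData, Table, select
    from sqlalchemy.testing import AssertsCompiledSQL

    class MyCompileTest(AssertsCompiledSQL):
        __dialect__ = "default"

        def test_simple_select(self):
            t = Table("t", MetaData(), Column("x", Integer))
            # the compiled "SELECT t.x \nFROM t" compares equal once
            # whitespace is folded
            self.assert_compile(select([t.c.x]), "SELECT t.x FROM t")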
diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py
index 7a525589d..d8e924cb6 100644
--- a/lib/sqlalchemy/testing/assertsql.py
+++ b/lib/sqlalchemy/testing/assertsql.py
@@ -26,8 +26,10 @@ class AssertRule(object):
pass
def no_more_statements(self):
- assert False, 'All statements are complete, but pending '\
- 'assertion rules remain'
+ assert False, (
+ "All statements are complete, but pending "
+ "assertion rules remain"
+ )
class SQLMatchRule(AssertRule):
@@ -44,12 +46,17 @@ class CursorSQL(SQLMatchRule):
def process_statement(self, execute_observed):
stmt = execute_observed.statements[0]
if self.statement != stmt.statement or (
- self.params is not None and self.params != stmt.parameters):
- self.errormessage = \
- "Testing for exact SQL %s parameters %s received %s %s" % (
- self.statement, self.params,
- stmt.statement, stmt.parameters
+ self.params is not None and self.params != stmt.parameters
+ ):
+ self.errormessage = (
+ "Testing for exact SQL %s parameters %s received %s %s"
+ % (
+ self.statement,
+ self.params,
+ stmt.statement,
+ stmt.parameters,
)
+ )
else:
execute_observed.statements.pop(0)
self.is_consumed = True
@@ -58,23 +65,22 @@ class CursorSQL(SQLMatchRule):
class CompiledSQL(SQLMatchRule):
-
- def __init__(self, statement, params=None, dialect='default'):
+ def __init__(self, statement, params=None, dialect="default"):
self.statement = statement
self.params = params
self.dialect = dialect
def _compare_sql(self, execute_observed, received_statement):
- stmt = re.sub(r'[\n\t]', '', self.statement)
+ stmt = re.sub(r"[\n\t]", "", self.statement)
return received_statement == stmt
def _compile_dialect(self, execute_observed):
- if self.dialect == 'default':
+ if self.dialect == "default":
return DefaultDialect()
else:
# ugh
- if self.dialect == 'postgresql':
- params = {'implicit_returning': True}
+ if self.dialect == "postgresql":
+ params = {"implicit_returning": True}
else:
params = {}
return url.URL(self.dialect).get_dialect()(**params)
@@ -86,36 +92,39 @@ class CompiledSQL(SQLMatchRule):
context = execute_observed.context
compare_dialect = self._compile_dialect(execute_observed)
if isinstance(context.compiled.statement, _DDLCompiles):
- compiled = \
- context.compiled.statement.compile(
- dialect=compare_dialect,
- schema_translate_map=context.
- execution_options.get('schema_translate_map'))
+ compiled = context.compiled.statement.compile(
+ dialect=compare_dialect,
+ schema_translate_map=context.execution_options.get(
+ "schema_translate_map"
+ ),
+ )
else:
- compiled = (
- context.compiled.statement.compile(
- dialect=compare_dialect,
- column_keys=context.compiled.column_keys,
- inline=context.compiled.inline,
- schema_translate_map=context.
- execution_options.get('schema_translate_map'))
+ compiled = context.compiled.statement.compile(
+ dialect=compare_dialect,
+ column_keys=context.compiled.column_keys,
+ inline=context.compiled.inline,
+ schema_translate_map=context.execution_options.get(
+ "schema_translate_map"
+ ),
)
- _received_statement = re.sub(r'[\n\t]', '', util.text_type(compiled))
+ _received_statement = re.sub(r"[\n\t]", "", util.text_type(compiled))
parameters = execute_observed.parameters
if not parameters:
_received_parameters = [compiled.construct_params()]
else:
_received_parameters = [
- compiled.construct_params(m) for m in parameters]
+ compiled.construct_params(m) for m in parameters
+ ]
return _received_statement, _received_parameters
def process_statement(self, execute_observed):
context = execute_observed.context
- _received_statement, _received_parameters = \
- self._received_statement(execute_observed)
+ _received_statement, _received_parameters = self._received_statement(
+ execute_observed
+ )
params = self._all_params(context)
equivalent = self._compare_sql(execute_observed, _received_statement)
@@ -132,8 +141,10 @@ class CompiledSQL(SQLMatchRule):
for param_key in param:
# a key in param did not match current
# 'received'
- if param_key not in received or \
- received[param_key] != param[param_key]:
+ if (
+ param_key not in received
+ or received[param_key] != param[param_key]
+ ):
break
else:
# all keys in param matched 'received';
@@ -153,8 +164,8 @@ class CompiledSQL(SQLMatchRule):
self.errormessage = None
else:
self.errormessage = self._failure_message(params) % {
- 'received_statement': _received_statement,
- 'received_parameters': _received_parameters
+ "received_statement": _received_statement,
+ "received_parameters": _received_parameters,
}
def _all_params(self, context):
@@ -171,11 +182,10 @@ class CompiledSQL(SQLMatchRule):
def _failure_message(self, expected_params):
return (
- 'Testing for compiled statement %r partial params %r, '
- 'received %%(received_statement)r with params '
- '%%(received_parameters)r' % (
- self.statement.replace('%', '%%'), expected_params
- )
+ "Testing for compiled statement %r partial params %r, "
+ "received %%(received_statement)r with params "
+ "%%(received_parameters)r"
+ % (self.statement.replace("%", "%%"), expected_params)
)
@@ -185,15 +195,13 @@ class RegexSQL(CompiledSQL):
self.regex = re.compile(regex)
self.orig_regex = regex
self.params = params
- self.dialect = 'default'
+ self.dialect = "default"
def _failure_message(self, expected_params):
return (
- 'Testing for compiled statement ~%r partial params %r, '
- 'received %%(received_statement)r with params '
- '%%(received_parameters)r' % (
- self.orig_regex, expected_params
- )
+ "Testing for compiled statement ~%r partial params %r, "
+ "received %%(received_statement)r with params "
+ "%%(received_parameters)r" % (self.orig_regex, expected_params)
)
def _compare_sql(self, execute_observed, received_statement):
@@ -205,12 +213,13 @@ class DialectSQL(CompiledSQL):
return execute_observed.context.dialect
def _compare_no_space(self, real_stmt, received_stmt):
- stmt = re.sub(r'[\n\t]', '', real_stmt)
+ stmt = re.sub(r"[\n\t]", "", real_stmt)
return received_stmt == stmt
def _received_statement(self, execute_observed):
- received_stmt, received_params = super(DialectSQL, self).\
- _received_statement(execute_observed)
+ received_stmt, received_params = super(
+ DialectSQL, self
+ )._received_statement(execute_observed)
# TODO: why do we need this part?
for real_stmt in execute_observed.statements:
@@ -219,34 +228,33 @@ class DialectSQL(CompiledSQL):
else:
raise AssertionError(
"Can't locate compiled statement %r in list of "
- "statements actually invoked" % received_stmt)
+ "statements actually invoked" % received_stmt
+ )
return received_stmt, execute_observed.context.compiled_parameters
def _compare_sql(self, execute_observed, received_statement):
- stmt = re.sub(r'[\n\t]', '', self.statement)
+ stmt = re.sub(r"[\n\t]", "", self.statement)
# convert our comparison statement to have the
# paramstyle of the received
paramstyle = execute_observed.context.dialect.paramstyle
- if paramstyle == 'pyformat':
- stmt = re.sub(
- r':([\w_]+)', r"%(\1)s", stmt)
+ if paramstyle == "pyformat":
+ stmt = re.sub(r":([\w_]+)", r"%(\1)s", stmt)
else:
# positional params
repl = None
- if paramstyle == 'qmark':
+ if paramstyle == "qmark":
repl = "?"
- elif paramstyle == 'format':
+ elif paramstyle == "format":
repl = r"%s"
- elif paramstyle == 'numeric':
+ elif paramstyle == "numeric":
repl = None
- stmt = re.sub(r':([\w_]+)', repl, stmt)
+ stmt = re.sub(r":([\w_]+)", repl, stmt)
return received_statement == stmt
class CountStatements(AssertRule):
-
def __init__(self, count):
self.count = count
self._statement_count = 0
@@ -256,12 +264,13 @@ class CountStatements(AssertRule):
def no_more_statements(self):
if self.count != self._statement_count:
- assert False, 'desired statement count %d does not match %d' \
- % (self.count, self._statement_count)
+ assert False, "desired statement count %d does not match %d" % (
+ self.count,
+ self._statement_count,
+ )
class AllOf(AssertRule):
-
def __init__(self, *rules):
self.rules = set(rules)
@@ -283,7 +292,6 @@ class AllOf(AssertRule):
class EachOf(AssertRule):
-
def __init__(self, *rules):
self.rules = list(rules)
@@ -309,7 +317,6 @@ class EachOf(AssertRule):
class Or(AllOf):
-
def process_statement(self, execute_observed):
for rule in self.rules:
rule.process_statement(execute_observed)
@@ -331,7 +338,8 @@ class SQLExecuteObserved(object):
class SQLCursorExecuteObserved(
collections.namedtuple(
"SQLCursorExecuteObserved",
- ["statement", "parameters", "context", "executemany"])
+ ["statement", "parameters", "context", "executemany"],
+ )
):
pass
@@ -374,21 +382,25 @@ def assert_engine(engine):
orig[:] = clauseelement, multiparams, params
@event.listens_for(engine, "after_cursor_execute")
- def cursor_execute(conn, cursor, statement, parameters,
- context, executemany):
+ def cursor_execute(
+ conn, cursor, statement, parameters, context, executemany
+ ):
if not context:
return
# then grab real cursor statements and associate them all
# around a single context
- if asserter.accumulated and \
- asserter.accumulated[-1].context is context:
+ if (
+ asserter.accumulated
+ and asserter.accumulated[-1].context is context
+ ):
obs = asserter.accumulated[-1]
else:
obs = SQLExecuteObserved(context, orig[0], orig[1], orig[2])
asserter.accumulated.append(obs)
obs.statements.append(
SQLCursorExecuteObserved(
- statement, parameters, context, executemany)
+ statement, parameters, context, executemany
+ )
)
try:
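
For orientation, a sketch (assumed; the users table and its DDL are hypothetical) of how the rules above are consumed: assert_sql_execution wraps a callable in assert_engine() and feeds each observed statement to rules such as CompiledSQL, AllOf or EachOf:

    from sqlalchemy import testing
    from sqlalchemy.testing import assertsql, fixtures

    class MyExecTest(fixtures.TestBase, testing.AssertsExecutionResults):
        def test_single_insert(self):
            self.assert_sql_execution(
                testing.db,
                lambda: testing.db.execute(
                    users.insert(), {"id": 1, "name": "ed"}  # assumed table
                ),
                assertsql.CompiledSQL(
                    "INSERT INTO users (id, name) VALUES (:id, :name)",
                    [{"id": 1, "name": "ed"}],
                ),
            )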
diff --git a/lib/sqlalchemy/testing/config.py b/lib/sqlalchemy/testing/config.py
index e9cfb3de9..1ff282af5 100644
--- a/lib/sqlalchemy/testing/config.py
+++ b/lib/sqlalchemy/testing/config.py
@@ -64,8 +64,9 @@ class Config(object):
assert _current, "Can't push without a default Config set up"
cls.push(
Config(
- db, _current.db_opts, _current.options, _current.file_config),
- namespace
+ db, _current.db_opts, _current.options, _current.file_config
+ ),
+ namespace,
)
@classmethod
@@ -94,4 +95,3 @@ class Config(object):
def skip_test(msg):
raise _skip_test_exception(msg)
-
diff --git a/lib/sqlalchemy/testing/engines.py b/lib/sqlalchemy/testing/engines.py
index d17e30edf..074e3b338 100644
--- a/lib/sqlalchemy/testing/engines.py
+++ b/lib/sqlalchemy/testing/engines.py
@@ -16,7 +16,6 @@ import warnings
class ConnectionKiller(object):
-
def __init__(self):
self.proxy_refs = weakref.WeakKeyDictionary()
self.testing_engines = weakref.WeakKeyDictionary()
@@ -39,8 +38,8 @@ class ConnectionKiller(object):
fn()
except Exception as e:
warnings.warn(
- "testing_reaper couldn't "
- "rollback/close connection: %s" % e)
+ "testing_reaper couldn't " "rollback/close connection: %s" % e
+ )
def rollback_all(self):
for rec in list(self.proxy_refs):
@@ -97,18 +96,19 @@ class ConnectionKiller(object):
if rec.is_valid:
assert False
+
testing_reaper = ConnectionKiller()
def drop_all_tables(metadata, bind):
testing_reaper.close_all()
- if hasattr(bind, 'close'):
+ if hasattr(bind, "close"):
bind.close()
if not config.db.dialect.supports_alter:
from . import assertions
- with assertions.expect_warnings(
- "Can't sort tables", assert_=False):
+
+ with assertions.expect_warnings("Can't sort tables", assert_=False):
metadata.drop_all(bind)
else:
metadata.drop_all(bind)
@@ -151,19 +151,20 @@ def close_open_connections(fn, *args, **kw):
def all_dialects(exclude=None):
import sqlalchemy.databases as d
+
for name in d.__all__:
# TEMPORARY
if exclude and name in exclude:
continue
mod = getattr(d, name, None)
if not mod:
- mod = getattr(__import__(
- 'sqlalchemy.databases.%s' % name).databases, name)
+ mod = getattr(
+ __import__("sqlalchemy.databases.%s" % name).databases, name
+ )
yield mod.dialect()
class ReconnectFixture(object):
-
def __init__(self, dbapi):
self.dbapi = dbapi
self.connections = []
@@ -191,8 +192,8 @@ class ReconnectFixture(object):
fn()
except Exception as e:
warnings.warn(
- "ReconnectFixture couldn't "
- "close connection: %s" % e)
+ "ReconnectFixture couldn't " "close connection: %s" % e
+ )
def shutdown(self, stop=False):
# TODO: this doesn't cover all cases
@@ -214,7 +215,7 @@ def reconnecting_engine(url=None, options=None):
dbapi = config.db.dialect.dbapi
if not options:
options = {}
- options['module'] = ReconnectFixture(dbapi)
+ options["module"] = ReconnectFixture(dbapi)
engine = testing_engine(url, options)
_dispose = engine.dispose
@@ -238,7 +239,7 @@ def testing_engine(url=None, options=None):
if not options:
use_reaper = True
else:
- use_reaper = options.pop('use_reaper', True)
+ use_reaper = options.pop("use_reaper", True)
url = url or config.db.url
@@ -253,15 +254,15 @@ def testing_engine(url=None, options=None):
default_opt.update(options)
engine = create_engine(url, **options)
- engine._has_events = True # enable event blocks, helps with profiling
+ engine._has_events = True # enable event blocks, helps with profiling
if isinstance(engine.pool, pool.QueuePool):
engine.pool._timeout = 0
engine.pool._max_overflow = 0
if use_reaper:
- event.listen(engine.pool, 'connect', testing_reaper.connect)
- event.listen(engine.pool, 'checkout', testing_reaper.checkout)
- event.listen(engine.pool, 'invalidate', testing_reaper.invalidate)
+ event.listen(engine.pool, "connect", testing_reaper.connect)
+ event.listen(engine.pool, "checkout", testing_reaper.checkout)
+ event.listen(engine.pool, "invalidate", testing_reaper.invalidate)
testing_reaper.add_engine(engine)
return engine
@@ -290,19 +291,17 @@ def mock_engine(dialect_name=None):
buffer.append(sql)
def assert_sql(stmts):
- recv = [re.sub(r'[\n\t]', '', str(s)) for s in buffer]
+ recv = [re.sub(r"[\n\t]", "", str(s)) for s in buffer]
assert recv == stmts, recv
def print_sql():
d = engine.dialect
- return "\n".join(
- str(s.compile(dialect=d))
- for s in engine.mock
- )
-
- engine = create_engine(dialect_name + '://',
- strategy='mock', executor=executor)
- assert not hasattr(engine, 'mock')
+ return "\n".join(str(s.compile(dialect=d)) for s in engine.mock)
+
+ engine = create_engine(
+ dialect_name + "://", strategy="mock", executor=executor
+ )
+ assert not hasattr(engine, "mock")
engine.mock = buffer
engine.assert_sql = assert_sql
engine.print_sql = print_sql
@@ -358,14 +357,15 @@ class DBAPIProxyConnection(object):
return getattr(self.conn, key)
-def proxying_engine(conn_cls=DBAPIProxyConnection,
- cursor_cls=DBAPIProxyCursor):
+def proxying_engine(
+ conn_cls=DBAPIProxyConnection, cursor_cls=DBAPIProxyCursor
+):
"""Produce an engine that provides proxy hooks for
common methods.
"""
+
def mock_conn():
return conn_cls(config.db, cursor_cls)
- return testing_engine(options={'creator': mock_conn})
-
+ return testing_engine(options={"creator": mock_conn})
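
A short usage sketch (assumed) for the factories touched above: testing_engine builds a disposable engine against the configured test database and registers it with the connection reaper, while mock_engine captures DDL instead of executing it:

    from sqlalchemy import Column, Integer, MetaData, Table
    from sqlalchemy.testing.engines import mock_engine, testing_engine

    eng = testing_engine(options={"use_reaper": True})

    m = mock_engine("sqlite")  # explicit dialect name for a standalone sketch
    md = MetaData()
    Table("w", md, Column("x", Integer))
    md.create_all(m)           # DDL is buffered, not executed
    print(m.print_sql())       # render what would have been emitted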
diff --git a/lib/sqlalchemy/testing/entities.py b/lib/sqlalchemy/testing/entities.py
index b634735fc..42c42149c 100644
--- a/lib/sqlalchemy/testing/entities.py
+++ b/lib/sqlalchemy/testing/entities.py
@@ -12,7 +12,6 @@ _repr_stack = set()
class BasicEntity(object):
-
def __init__(self, **kw):
for key, value in kw.items():
setattr(self, key, value)
@@ -24,17 +23,22 @@ class BasicEntity(object):
try:
return "%s(%s)" % (
(self.__class__.__name__),
- ', '.join(["%s=%r" % (key, getattr(self, key))
- for key in sorted(self.__dict__.keys())
- if not key.startswith('_')]))
+ ", ".join(
+ [
+ "%s=%r" % (key, getattr(self, key))
+ for key in sorted(self.__dict__.keys())
+ if not key.startswith("_")
+ ]
+ ),
+ )
finally:
_repr_stack.remove(id(self))
+
_recursion_stack = set()
class ComparableEntity(BasicEntity):
-
def __hash__(self):
return hash(self.__class__)
@@ -75,7 +79,7 @@ class ComparableEntity(BasicEntity):
b = other
for attr in list(a.__dict__):
- if attr.startswith('_'):
+ if attr.startswith("_"):
continue
value = getattr(a, attr)
@@ -85,9 +89,10 @@ class ComparableEntity(BasicEntity):
except (AttributeError, sa_exc.UnboundExecutionError):
return False
- if hasattr(value, '__iter__'):
- if hasattr(value, '__getitem__') and not hasattr(
- value, 'keys'):
+ if hasattr(value, "__iter__"):
+ if hasattr(value, "__getitem__") and not hasattr(
+ value, "keys"
+ ):
if list(value) != list(battr):
return False
else:
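
What this buys fixtures, in brief (an assumed sketch): BasicEntity accepts arbitrary keyword attributes, and ComparableEntity compares every non-underscore attribute, so plain value objects can stand in for mapped rows in result assertions:

    from sqlalchemy.testing.entities import ComparableEntity

    class User(ComparableEntity):
        pass

    assert User(id=1, name="ed") == User(id=1, name="ed")
    assert User(id=1, name="ed") != User(id=2, name="ed")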
diff --git a/lib/sqlalchemy/testing/exclusions.py b/lib/sqlalchemy/testing/exclusions.py
index 512fffb3b..9ed9e42c3 100644
--- a/lib/sqlalchemy/testing/exclusions.py
+++ b/lib/sqlalchemy/testing/exclusions.py
@@ -16,6 +16,7 @@ from . import config
from .. import util
from ..util import decorator
+
def skip_if(predicate, reason=None):
rule = compound()
pred = _as_predicate(predicate, reason)
@@ -70,15 +71,15 @@ class compound(object):
def matching_config_reasons(self, config):
return [
- predicate._as_string(config) for predicate
- in self.skips.union(self.fails)
+ predicate._as_string(config)
+ for predicate in self.skips.union(self.fails)
if predicate(config)
]
def include_test(self, include_tags, exclude_tags):
return bool(
- not self.tags.intersection(exclude_tags) and
- (not include_tags or self.tags.intersection(include_tags))
+ not self.tags.intersection(exclude_tags)
+ and (not include_tags or self.tags.intersection(include_tags))
)
def _extend(self, other):
@@ -87,13 +88,14 @@ class compound(object):
self.tags.update(other.tags)
def __call__(self, fn):
- if hasattr(fn, '_sa_exclusion_extend'):
+ if hasattr(fn, "_sa_exclusion_extend"):
fn._sa_exclusion_extend._extend(self)
return fn
@decorator
def decorate(fn, *args, **kw):
return self._do(config._current, fn, *args, **kw)
+
decorated = decorate(fn)
decorated._sa_exclusion_extend = self
return decorated
@@ -113,10 +115,7 @@ class compound(object):
def _do(self, cfg, fn, *args, **kw):
for skip in self.skips:
if skip(cfg):
- msg = "'%s' : %s" % (
- fn.__name__,
- skip._as_string(cfg)
- )
+ msg = "'%s' : %s" % (fn.__name__, skip._as_string(cfg))
config.skip_test(msg)
try:
@@ -127,16 +126,20 @@ class compound(object):
self._expect_success(cfg, name=fn.__name__)
return return_value
- def _expect_failure(self, config, ex, name='block'):
+ def _expect_failure(self, config, ex, name="block"):
for fail in self.fails:
if fail(config):
- print(("%s failed as expected (%s): %s " % (
- name, fail._as_string(config), str(ex))))
+ print(
+ (
+ "%s failed as expected (%s): %s "
+ % (name, fail._as_string(config), str(ex))
+ )
+ )
break
else:
util.raise_from_cause(ex)
- def _expect_success(self, config, name='block'):
+ def _expect_success(self, config, name="block"):
if not self.fails:
return
for fail in self.fails:
@@ -144,13 +147,12 @@ class compound(object):
break
else:
raise AssertionError(
- "Unexpected success for '%s' (%s)" %
- (
+ "Unexpected success for '%s' (%s)"
+ % (
name,
" and ".join(
- fail._as_string(config)
- for fail in self.fails
- )
+ fail._as_string(config) for fail in self.fails
+ ),
)
)
@@ -186,21 +188,24 @@ class Predicate(object):
return predicate
elif isinstance(predicate, (list, set)):
return OrPredicate(
- [cls.as_predicate(pred) for pred in predicate],
- description)
+ [cls.as_predicate(pred) for pred in predicate], description
+ )
elif isinstance(predicate, tuple):
return SpecPredicate(*predicate)
elif isinstance(predicate, util.string_types):
tokens = re.match(
- r'([\+\w]+)\s*(?:(>=|==|!=|<=|<|>)\s*([\d\.]+))?', predicate)
+ r"([\+\w]+)\s*(?:(>=|==|!=|<=|<|>)\s*([\d\.]+))?", predicate
+ )
if not tokens:
raise ValueError(
- "Couldn't locate DB name in predicate: %r" % predicate)
+ "Couldn't locate DB name in predicate: %r" % predicate
+ )
db = tokens.group(1)
op = tokens.group(2)
spec = (
tuple(int(d) for d in tokens.group(3).split("."))
- if tokens.group(3) else None
+ if tokens.group(3)
+ else None
)
return SpecPredicate(db, op, spec, description=description)
@@ -215,11 +220,13 @@ class Predicate(object):
bool_ = not negate
return self.description % {
"driver": config.db.url.get_driver_name()
- if config else "<no driver>",
+ if config
+ else "<no driver>",
"database": config.db.url.get_backend_name()
- if config else "<no database>",
+ if config
+ else "<no database>",
"doesnt_support": "doesn't support" if bool_ else "does support",
- "does_support": "does support" if bool_ else "doesn't support"
+ "does_support": "does support" if bool_ else "doesn't support",
}
def _as_string(self, config=None, negate=False):
@@ -246,21 +253,21 @@ class SpecPredicate(Predicate):
self.description = description
_ops = {
- '<': operator.lt,
- '>': operator.gt,
- '==': operator.eq,
- '!=': operator.ne,
- '<=': operator.le,
- '>=': operator.ge,
- 'in': operator.contains,
- 'between': lambda val, pair: val >= pair[0] and val <= pair[1],
+ "<": operator.lt,
+ ">": operator.gt,
+ "==": operator.eq,
+ "!=": operator.ne,
+ "<=": operator.le,
+ ">=": operator.ge,
+ "in": operator.contains,
+ "between": lambda val, pair: val >= pair[0] and val <= pair[1],
}
def __call__(self, config):
engine = config.db
if "+" in self.db:
- dialect, driver = self.db.split('+')
+ dialect, driver = self.db.split("+")
else:
dialect, driver = self.db, None
@@ -273,8 +280,9 @@ class SpecPredicate(Predicate):
assert driver is None, "DBAPI version specs not supported yet"
version = _server_version(engine)
- oper = hasattr(self.op, '__call__') and self.op \
- or self._ops[self.op]
+ oper = (
+ hasattr(self.op, "__call__") and self.op or self._ops[self.op]
+ )
return oper(version, self.spec)
else:
return True
@@ -289,17 +297,9 @@ class SpecPredicate(Predicate):
return "%s" % self.db
else:
if negate:
- return "not %s %s %s" % (
- self.db,
- self.op,
- self.spec
- )
+ return "not %s %s %s" % (self.db, self.op, self.spec)
else:
- return "%s %s %s" % (
- self.db,
- self.op,
- self.spec
- )
+ return "%s %s %s" % (self.db, self.op, self.spec)
class LambdaPredicate(Predicate):
@@ -356,8 +356,9 @@ class OrPredicate(Predicate):
conjunction = " and "
else:
conjunction = " or "
- return conjunction.join(p._as_string(config, negate=negate)
- for p in self.predicates)
+ return conjunction.join(
+ p._as_string(config, negate=negate) for p in self.predicates
+ )
def _negation_str(self, config):
if self.description is not None:
@@ -387,7 +388,7 @@ def _server_version(engine):
# force metadata to be retrieved
conn = engine.connect()
- version = getattr(engine.dialect, 'server_version_info', None)
+ version = getattr(engine.dialect, "server_version_info", None)
if version is None:
version = ()
conn.close()
@@ -395,9 +396,7 @@ def _server_version(engine):
def db_spec(*dbs):
- return OrPredicate(
- [Predicate.as_predicate(db) for db in dbs]
- )
+ return OrPredicate([Predicate.as_predicate(db) for db in dbs])
def open():
@@ -422,11 +421,7 @@ def fails_on(db, reason=None):
def fails_on_everything_except(*dbs):
- return succeeds_if(
- OrPredicate([
- Predicate.as_predicate(db) for db in dbs
- ])
- )
+ return succeeds_if(OrPredicate([Predicate.as_predicate(db) for db in dbs]))
def skip(db, reason=None):
@@ -435,8 +430,9 @@ def skip(db, reason=None):
def only_on(dbs, reason=None):
return only_if(
- OrPredicate([Predicate.as_predicate(db, reason)
- for db in util.to_list(dbs)])
+ OrPredicate(
+ [Predicate.as_predicate(db, reason) for db in util.to_list(dbs)]
+ )
)
@@ -446,7 +442,6 @@ def exclude(db, op, spec, reason=None):
def against(config, *queries):
assert queries, "no queries sent!"
- return OrPredicate([
- Predicate.as_predicate(query)
- for query in queries
- ])(config)
+ return OrPredicate([Predicate.as_predicate(query) for query in queries])(
+ config
+ )
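
As a usage sketch (assumed; the backends and versions are illustrative), the predicate grammar parsed above drives the public decorators: a bare string names a backend, an optional comparison pins a server version, "backend+driver" targets a DBAPI, and a lambda receives the active config:

    from sqlalchemy.testing import exclusions, fixtures

    class SomeTest(fixtures.TestBase):
        @exclusions.only_on("postgresql >= 9.4")
        def test_pg_only(self):
            ...

        @exclusions.fails_on("mysql+pymysql", "illustrative reason")
        def test_known_failure(self):
            ...

        @exclusions.skip_if(
            lambda config: not config.db.dialect.supports_sequences
        )
        def test_needs_sequences(self):
            ...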
diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py
index dd0fa5a48..98184cdd4 100644
--- a/lib/sqlalchemy/testing/fixtures.py
+++ b/lib/sqlalchemy/testing/fixtures.py
@@ -54,19 +54,19 @@ class TestBase(object):
class TablesTest(TestBase):
# 'once', None
- run_setup_bind = 'once'
+ run_setup_bind = "once"
# 'once', 'each', None
- run_define_tables = 'once'
+ run_define_tables = "once"
# 'once', 'each', None
- run_create_tables = 'once'
+ run_create_tables = "once"
# 'once', 'each', None
- run_inserts = 'each'
+ run_inserts = "each"
# 'each', None
- run_deletes = 'each'
+ run_deletes = "each"
# 'once', None
run_dispose_bind = None
@@ -86,10 +86,10 @@ class TablesTest(TestBase):
@classmethod
def _init_class(cls):
- if cls.run_define_tables == 'each':
- if cls.run_create_tables == 'once':
- cls.run_create_tables = 'each'
- assert cls.run_inserts in ('each', None)
+ if cls.run_define_tables == "each":
+ if cls.run_create_tables == "once":
+ cls.run_create_tables = "each"
+ assert cls.run_inserts in ("each", None)
cls.other = adict()
cls.tables = adict()
@@ -100,40 +100,40 @@ class TablesTest(TestBase):
@classmethod
def _setup_once_inserts(cls):
- if cls.run_inserts == 'once':
+ if cls.run_inserts == "once":
cls._load_fixtures()
cls.insert_data()
@classmethod
def _setup_once_tables(cls):
- if cls.run_define_tables == 'once':
+ if cls.run_define_tables == "once":
cls.define_tables(cls.metadata)
- if cls.run_create_tables == 'once':
+ if cls.run_create_tables == "once":
cls.metadata.create_all(cls.bind)
cls.tables.update(cls.metadata.tables)
def _setup_each_tables(self):
- if self.run_define_tables == 'each':
+ if self.run_define_tables == "each":
self.tables.clear()
- if self.run_create_tables == 'each':
+ if self.run_create_tables == "each":
drop_all_tables(self.metadata, self.bind)
self.metadata.clear()
self.define_tables(self.metadata)
- if self.run_create_tables == 'each':
+ if self.run_create_tables == "each":
self.metadata.create_all(self.bind)
self.tables.update(self.metadata.tables)
- elif self.run_create_tables == 'each':
+ elif self.run_create_tables == "each":
drop_all_tables(self.metadata, self.bind)
self.metadata.create_all(self.bind)
def _setup_each_inserts(self):
- if self.run_inserts == 'each':
+ if self.run_inserts == "each":
self._load_fixtures()
self.insert_data()
def _teardown_each_tables(self):
# no need to run deletes if tables are recreated on setup
- if self.run_define_tables != 'each' and self.run_deletes == 'each':
+ if self.run_define_tables != "each" and self.run_deletes == "each":
with self.bind.connect() as conn:
for table in reversed(self.metadata.sorted_tables):
try:
@@ -141,7 +141,8 @@ class TablesTest(TestBase):
except sa.exc.DBAPIError as ex:
util.print_(
("Error emptying table %s: %r" % (table, ex)),
- file=sys.stderr)
+ file=sys.stderr,
+ )
def setup(self):
self._setup_each_tables()
@@ -155,7 +156,7 @@ class TablesTest(TestBase):
if cls.run_create_tables:
drop_all_tables(cls.metadata, cls.bind)
- if cls.run_dispose_bind == 'once':
+ if cls.run_dispose_bind == "once":
cls.dispose_bind(cls.bind)
cls.metadata.bind = None
@@ -173,9 +174,9 @@ class TablesTest(TestBase):
@classmethod
def dispose_bind(cls, bind):
- if hasattr(bind, 'dispose'):
+ if hasattr(bind, "dispose"):
bind.dispose()
- elif hasattr(bind, 'close'):
+ elif hasattr(bind, "close"):
bind.close()
@classmethod
@@ -212,8 +213,12 @@ class TablesTest(TestBase):
continue
cls.bind.execute(
table.insert(),
- [dict(zip(headers[table], column_values))
- for column_values in rows[table]])
+ [
+ dict(zip(headers[table], column_values))
+ for column_values in rows[table]
+ ],
+ )
+
from sqlalchemy import event
@@ -236,7 +241,6 @@ class RemovesEvents(object):
class _ORMTest(object):
-
@classmethod
def teardown_class(cls):
sa.orm.session.Session.close_all()
@@ -249,10 +253,10 @@ class ORMTest(_ORMTest, TestBase):
class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults):
# 'once', 'each', None
- run_setup_classes = 'once'
+ run_setup_classes = "once"
# 'once', 'each', None
- run_setup_mappers = 'each'
+ run_setup_mappers = "each"
classes = None
@@ -292,20 +296,20 @@ class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults):
@classmethod
def _setup_once_classes(cls):
- if cls.run_setup_classes == 'once':
+ if cls.run_setup_classes == "once":
cls._with_register_classes(cls.setup_classes)
@classmethod
def _setup_once_mappers(cls):
- if cls.run_setup_mappers == 'once':
+ if cls.run_setup_mappers == "once":
cls._with_register_classes(cls.setup_mappers)
def _setup_each_mappers(self):
- if self.run_setup_mappers == 'each':
+ if self.run_setup_mappers == "each":
self._with_register_classes(self.setup_mappers)
def _setup_each_classes(self):
- if self.run_setup_classes == 'each':
+ if self.run_setup_classes == "each":
self._with_register_classes(self.setup_classes)
@classmethod
@@ -339,11 +343,11 @@ class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults):
# some tests create mappers in the test bodies
# and will define setup_mappers as None -
# clear mappers in any case
- if self.run_setup_mappers != 'once':
+ if self.run_setup_mappers != "once":
sa.orm.clear_mappers()
def _teardown_each_classes(self):
- if self.run_setup_classes != 'once':
+ if self.run_setup_classes != "once":
self.classes.clear()
@classmethod
@@ -356,8 +360,8 @@ class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults):
class DeclarativeMappedTest(MappedTest):
- run_setup_classes = 'once'
- run_setup_mappers = 'once'
+ run_setup_classes = "once"
+ run_setup_mappers = "once"
@classmethod
def _setup_once_tables(cls):
@@ -370,15 +374,16 @@ class DeclarativeMappedTest(MappedTest):
class FindFixtureDeclarative(DeclarativeMeta):
def __init__(cls, classname, bases, dict_):
cls_registry[classname] = cls
- return DeclarativeMeta.__init__(
- cls, classname, bases, dict_)
+ return DeclarativeMeta.__init__(cls, classname, bases, dict_)
class DeclarativeBasic(object):
__table_cls__ = schema.Table
- _DeclBase = declarative_base(metadata=cls.metadata,
- metaclass=FindFixtureDeclarative,
- cls=DeclarativeBasic)
+ _DeclBase = declarative_base(
+ metadata=cls.metadata,
+ metaclass=FindFixtureDeclarative,
+ cls=DeclarativeBasic,
+ )
cls.DeclarativeBasic = _DeclBase
fn()
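
To make the run_* flags concrete, a minimal TablesTest subclass (assumed example): tables are defined and created once per class, fixture rows are reinserted before each test, and created tables are reachable through the cls.tables adict:

    from sqlalchemy import Column, Integer, String, Table
    from sqlalchemy.testing import fixtures

    class UserFixtureTest(fixtures.TablesTest):
        run_create_tables = "once"
        run_inserts = "each"

        @classmethod
        def define_tables(cls, metadata):
            Table(
                "users",
                metadata,
                Column("id", Integer, primary_key=True),
                Column("name", String(30)),
            )

        @classmethod
        def insert_data(cls):
            cls.bind.execute(
                cls.tables.users.insert(), [{"id": 1, "name": "ed"}]
            )

        def test_round_trip(self):
            row = self.bind.execute(self.tables.users.select()).first()
            assert row.name == "ed"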
diff --git a/lib/sqlalchemy/testing/mock.py b/lib/sqlalchemy/testing/mock.py
index ea0a8da82..dc530af5e 100644
--- a/lib/sqlalchemy/testing/mock.py
+++ b/lib/sqlalchemy/testing/mock.py
@@ -18,4 +18,5 @@ else:
except ImportError:
raise ImportError(
"SQLAlchemy's test suite requires the "
- "'mock' library as of 0.8.2.")
+ "'mock' library as of 0.8.2."
+ )
diff --git a/lib/sqlalchemy/testing/pickleable.py b/lib/sqlalchemy/testing/pickleable.py
index 087fc1fe6..e84cbde44 100644
--- a/lib/sqlalchemy/testing/pickleable.py
+++ b/lib/sqlalchemy/testing/pickleable.py
@@ -46,29 +46,28 @@ class Parent(fixtures.ComparableEntity):
class Screen(object):
-
def __init__(self, obj, parent=None):
self.obj = obj
self.parent = parent
class Foo(object):
-
def __init__(self, moredata):
- self.data = 'im data'
- self.stuff = 'im stuff'
+ self.data = "im data"
+ self.stuff = "im stuff"
self.moredata = moredata
__hash__ = object.__hash__
def __eq__(self, other):
- return other.data == self.data and \
- other.stuff == self.stuff and \
- other.moredata == self.moredata
+ return (
+ other.data == self.data
+ and other.stuff == self.stuff
+ and other.moredata == self.moredata
+ )
class Bar(object):
-
def __init__(self, x, y):
self.x = x
self.y = y
@@ -76,35 +75,36 @@ class Bar(object):
__hash__ = object.__hash__
def __eq__(self, other):
- return other.__class__ is self.__class__ and \
- other.x == self.x and \
- other.y == self.y
+ return (
+ other.__class__ is self.__class__
+ and other.x == self.x
+ and other.y == self.y
+ )
def __str__(self):
return "Bar(%d, %d)" % (self.x, self.y)
class OldSchool:
-
def __init__(self, x, y):
self.x = x
self.y = y
def __eq__(self, other):
- return other.__class__ is self.__class__ and \
- other.x == self.x and \
- other.y == self.y
+ return (
+ other.__class__ is self.__class__
+ and other.x == self.x
+ and other.y == self.y
+ )
class OldSchoolWithoutCompare:
-
def __init__(self, x, y):
self.x = x
self.y = y
class BarWithoutCompare(object):
-
def __init__(self, x, y):
self.x = x
self.y = y
@@ -114,7 +114,6 @@ class BarWithoutCompare(object):
class NotComparable(object):
-
def __init__(self, data):
self.data = data
@@ -129,7 +128,6 @@ class NotComparable(object):
class BrokenComparable(object):
-
def __init__(self, data):
self.data = data
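
These classes exist to be round-tripped through pickle in ORM tests; for instance (a sketch, not part of the patch):

    import pickle

    from sqlalchemy.testing.pickleable import Foo

    f = Foo("extra")
    # Foo.__eq__ compares data, stuff and moredata, so the copy tests equal
    assert pickle.loads(pickle.dumps(f)) == f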
diff --git a/lib/sqlalchemy/testing/plugin/bootstrap.py b/lib/sqlalchemy/testing/plugin/bootstrap.py
index 497fcb7e5..bb52c125c 100644
--- a/lib/sqlalchemy/testing/plugin/bootstrap.py
+++ b/lib/sqlalchemy/testing/plugin/bootstrap.py
@@ -20,20 +20,23 @@ this should be removable when Alembic targets SQLAlchemy 1.0.0.
import os
import sys
-bootstrap_file = locals()['bootstrap_file']
-to_bootstrap = locals()['to_bootstrap']
+bootstrap_file = locals()["bootstrap_file"]
+to_bootstrap = locals()["to_bootstrap"]
def load_file_as_module(name):
path = os.path.join(os.path.dirname(bootstrap_file), "%s.py" % name)
if sys.version_info >= (3, 3):
from importlib import machinery
+
mod = machinery.SourceFileLoader(name, path).load_module()
else:
import imp
+
mod = imp.load_source(name, path)
return mod
+
if to_bootstrap == "pytest":
sys.modules["sqla_plugin_base"] = load_file_as_module("plugin_base")
sys.modules["sqla_pytestplugin"] = load_file_as_module("pytestplugin")
diff --git a/lib/sqlalchemy/testing/plugin/noseplugin.py b/lib/sqlalchemy/testing/plugin/noseplugin.py
index 20ea61d89..0c28a5213 100644
--- a/lib/sqlalchemy/testing/plugin/noseplugin.py
+++ b/lib/sqlalchemy/testing/plugin/noseplugin.py
@@ -25,6 +25,7 @@ import sys
from nose.plugins import Plugin
import nose
+
fixtures = None
py3k = sys.version_info >= (3, 0)
@@ -33,7 +34,7 @@ py3k = sys.version_info >= (3, 0)
class NoseSQLAlchemy(Plugin):
enabled = True
- name = 'sqla_testing'
+ name = "sqla_testing"
score = 100
def options(self, parser, env=os.environ):
@@ -41,10 +42,14 @@ class NoseSQLAlchemy(Plugin):
opt = parser.add_option
def make_option(name, **kw):
- callback_ = kw.pop("callback", None) or kw.pop("zeroarg_callback", None)
+ callback_ = kw.pop("callback", None) or kw.pop(
+ "zeroarg_callback", None
+ )
if callback_:
+
def wrap_(option, opt_str, value, parser):
callback_(opt_str, value, parser)
+
kw["callback"] = wrap_
opt(name, **kw)
@@ -73,7 +78,7 @@ class NoseSQLAlchemy(Plugin):
def wantMethod(self, fn):
if py3k:
- if not hasattr(fn.__self__, 'cls'):
+ if not hasattr(fn.__self__, "cls"):
return False
cls = fn.__self__.cls
else:
@@ -84,24 +89,24 @@ class NoseSQLAlchemy(Plugin):
return plugin_base.want_class(cls)
def beforeTest(self, test):
- if not hasattr(test.test, 'cls'):
+ if not hasattr(test.test, "cls"):
return
plugin_base.before_test(
test,
test.test.cls.__module__,
- test.test.cls, test.test.method.__name__)
+ test.test.cls,
+ test.test.method.__name__,
+ )
def afterTest(self, test):
plugin_base.after_test(test)
def startContext(self, ctx):
- if not isinstance(ctx, type) \
- or not issubclass(ctx, fixtures.TestBase):
+ if not isinstance(ctx, type) or not issubclass(ctx, fixtures.TestBase):
return
plugin_base.start_test_class(ctx)
def stopContext(self, ctx):
- if not isinstance(ctx, type) \
- or not issubclass(ctx, fixtures.TestBase):
+ if not isinstance(ctx, type) or not issubclass(ctx, fixtures.TestBase):
return
plugin_base.stop_test_class(ctx)
diff --git a/lib/sqlalchemy/testing/plugin/plugin_base.py b/lib/sqlalchemy/testing/plugin/plugin_base.py
index 0ffcae093..5d6bf2975 100644
--- a/lib/sqlalchemy/testing/plugin/plugin_base.py
+++ b/lib/sqlalchemy/testing/plugin/plugin_base.py
@@ -46,58 +46,130 @@ options = None
def setup_options(make_option):
- make_option("--log-info", action="callback", type="string", callback=_log,
- help="turn on info logging for <LOG> (multiple OK)")
- make_option("--log-debug", action="callback",
- type="string", callback=_log,
- help="turn on debug logging for <LOG> (multiple OK)")
- make_option("--db", action="append", type="string", dest="db",
- help="Use prefab database uri. Multiple OK, "
- "first one is run by default.")
- make_option('--dbs', action='callback', zeroarg_callback=_list_dbs,
- help="List available prefab dbs")
- make_option("--dburi", action="append", type="string", dest="dburi",
- help="Database uri. Multiple OK, "
- "first one is run by default.")
- make_option("--dropfirst", action="store_true", dest="dropfirst",
- help="Drop all tables in the target database first")
- make_option("--backend-only", action="store_true", dest="backend_only",
- help="Run only tests marked with __backend__")
- make_option("--nomemory", action="store_true", dest="nomemory",
- help="Don't run memory profiling tests")
- make_option("--postgresql-templatedb", type="string",
- help="name of template database to use for PostgreSQL "
- "CREATE DATABASE (defaults to current database)")
- make_option("--low-connections", action="store_true",
- dest="low_connections",
- help="Use a low number of distinct connections - "
- "i.e. for Oracle TNS")
- make_option("--write-idents", type="string", dest="write_idents",
- help="write out generated follower idents to <file>, "
- "when -n<num> is used")
- make_option("--reversetop", action="store_true",
- dest="reversetop", default=False,
- help="Use a random-ordering set implementation in the ORM "
- "(helps reveal dependency issues)")
- make_option("--requirements", action="callback", type="string",
- callback=_requirements_opt,
- help="requirements class for testing, overrides setup.cfg")
- make_option("--with-cdecimal", action="store_true",
- dest="cdecimal", default=False,
- help="Monkeypatch the cdecimal library into Python 'decimal' "
- "for all tests")
- make_option("--include-tag", action="callback", callback=_include_tag,
- type="string",
- help="Include tests with tag <tag>")
- make_option("--exclude-tag", action="callback", callback=_exclude_tag,
- type="string",
- help="Exclude tests with tag <tag>")
- make_option("--write-profiles", action="store_true",
- dest="write_profiles", default=False,
- help="Write/update failing profiling data.")
- make_option("--force-write-profiles", action="store_true",
- dest="force_write_profiles", default=False,
- help="Unconditionally write/update profiling data.")
+ make_option(
+ "--log-info",
+ action="callback",
+ type="string",
+ callback=_log,
+ help="turn on info logging for <LOG> (multiple OK)",
+ )
+ make_option(
+ "--log-debug",
+ action="callback",
+ type="string",
+ callback=_log,
+ help="turn on debug logging for <LOG> (multiple OK)",
+ )
+ make_option(
+ "--db",
+ action="append",
+ type="string",
+ dest="db",
+ help="Use prefab database uri. Multiple OK, "
+ "first one is run by default.",
+ )
+ make_option(
+ "--dbs",
+ action="callback",
+ zeroarg_callback=_list_dbs,
+ help="List available prefab dbs",
+ )
+ make_option(
+ "--dburi",
+ action="append",
+ type="string",
+ dest="dburi",
+ help="Database uri. Multiple OK, " "first one is run by default.",
+ )
+ make_option(
+ "--dropfirst",
+ action="store_true",
+ dest="dropfirst",
+ help="Drop all tables in the target database first",
+ )
+ make_option(
+ "--backend-only",
+ action="store_true",
+ dest="backend_only",
+ help="Run only tests marked with __backend__",
+ )
+ make_option(
+ "--nomemory",
+ action="store_true",
+ dest="nomemory",
+ help="Don't run memory profiling tests",
+ )
+ make_option(
+ "--postgresql-templatedb",
+ type="string",
+ help="name of template database to use for PostgreSQL "
+ "CREATE DATABASE (defaults to current database)",
+ )
+ make_option(
+ "--low-connections",
+ action="store_true",
+ dest="low_connections",
+ help="Use a low number of distinct connections - "
+ "i.e. for Oracle TNS",
+ )
+ make_option(
+ "--write-idents",
+ type="string",
+ dest="write_idents",
+ help="write out generated follower idents to <file>, "
+ "when -n<num> is used",
+ )
+ make_option(
+ "--reversetop",
+ action="store_true",
+ dest="reversetop",
+ default=False,
+ help="Use a random-ordering set implementation in the ORM "
+ "(helps reveal dependency issues)",
+ )
+ make_option(
+ "--requirements",
+ action="callback",
+ type="string",
+ callback=_requirements_opt,
+ help="requirements class for testing, overrides setup.cfg",
+ )
+ make_option(
+ "--with-cdecimal",
+ action="store_true",
+ dest="cdecimal",
+ default=False,
+ help="Monkeypatch the cdecimal library into Python 'decimal' "
+ "for all tests",
+ )
+ make_option(
+ "--include-tag",
+ action="callback",
+ callback=_include_tag,
+ type="string",
+ help="Include tests with tag <tag>",
+ )
+ make_option(
+ "--exclude-tag",
+ action="callback",
+ callback=_exclude_tag,
+ type="string",
+ help="Exclude tests with tag <tag>",
+ )
+ make_option(
+ "--write-profiles",
+ action="store_true",
+ dest="write_profiles",
+ default=False,
+ help="Write/update failing profiling data.",
+ )
+ make_option(
+ "--force-write-profiles",
+ action="store_true",
+ dest="force_write_profiles",
+ default=False,
+ help="Unconditionally write/update profiling data.",
+ )
def configure_follower(follower_ident):
@@ -108,6 +180,7 @@ def configure_follower(follower_ident):
"""
from sqlalchemy.testing import provision
+
provision.FOLLOWER_IDENT = follower_ident
@@ -121,9 +194,9 @@ def memoize_important_follower_config(dict_):
callables, so we have to just copy all of that over.
"""
- dict_['memoized_config'] = {
- 'include_tags': include_tags,
- 'exclude_tags': exclude_tags
+ dict_["memoized_config"] = {
+ "include_tags": include_tags,
+ "exclude_tags": exclude_tags,
}
@@ -134,14 +207,14 @@ def restore_important_follower_config(dict_):
"""
global include_tags, exclude_tags
- include_tags.update(dict_['memoized_config']['include_tags'])
- exclude_tags.update(dict_['memoized_config']['exclude_tags'])
+ include_tags.update(dict_["memoized_config"]["include_tags"])
+ exclude_tags.update(dict_["memoized_config"]["exclude_tags"])
def read_config():
global file_config
file_config = configparser.ConfigParser()
- file_config.read(['setup.cfg', 'test.cfg'])
+ file_config.read(["setup.cfg", "test.cfg"])
def pre_begin(opt):
@@ -155,6 +228,7 @@ def pre_begin(opt):
def set_coverage_flag(value):
options.has_coverage = value
+
_skip_test_exception = None
@@ -171,34 +245,33 @@ def post_begin():
# late imports, has to happen after config as well
# as nose plugins like coverage
- global util, fixtures, engines, exclusions, \
- assertions, warnings, profiling,\
- config, testing
- from sqlalchemy import testing # noqa
+ global util, fixtures, engines, exclusions, assertions, warnings, profiling, config, testing
+ from sqlalchemy import testing # noqa
from sqlalchemy.testing import fixtures, engines, exclusions # noqa
- from sqlalchemy.testing import assertions, warnings, profiling # noqa
+ from sqlalchemy.testing import assertions, warnings, profiling # noqa
from sqlalchemy.testing import config # noqa
from sqlalchemy import util # noqa
- warnings.setup_filters()
+ warnings.setup_filters()
def _log(opt_str, value, parser):
global logging
if not logging:
import logging
+
logging.basicConfig()
- if opt_str.endswith('-info'):
+ if opt_str.endswith("-info"):
logging.getLogger(value).setLevel(logging.INFO)
- elif opt_str.endswith('-debug'):
+ elif opt_str.endswith("-debug"):
logging.getLogger(value).setLevel(logging.DEBUG)
def _list_dbs(*args):
print("Available --db options (use --dburi to override)")
- for macro in sorted(file_config.options('db')):
- print("%20s\t%s" % (macro, file_config.get('db', macro)))
+ for macro in sorted(file_config.options("db")):
+ print("%20s\t%s" % (macro, file_config.get("db", macro)))
sys.exit(0)
@@ -207,11 +280,12 @@ def _requirements_opt(opt_str, value, parser):
def _exclude_tag(opt_str, value, parser):
- exclude_tags.add(value.replace('-', '_'))
+ exclude_tags.add(value.replace("-", "_"))
def _include_tag(opt_str, value, parser):
- include_tags.add(value.replace('-', '_'))
+ include_tags.add(value.replace("-", "_"))
+
pre_configure = []
post_configure = []
@@ -243,7 +317,8 @@ def _set_nomemory(opt, file_config):
def _monkeypatch_cdecimal(options, file_config):
if options.cdecimal:
import cdecimal
- sys.modules['decimal'] = cdecimal
+
+ sys.modules["decimal"] = cdecimal
@post
@@ -266,27 +341,28 @@ def _engine_uri(options, file_config):
if options.db:
for db_token in options.db:
- for db in re.split(r'[,\s]+', db_token):
- if db not in file_config.options('db'):
+ for db in re.split(r"[,\s]+", db_token):
+ if db not in file_config.options("db"):
raise RuntimeError(
"Unknown URI specifier '%s'. "
- "Specify --dbs for known uris."
- % db)
+ "Specify --dbs for known uris." % db
+ )
else:
- db_urls.append(file_config.get('db', db))
+ db_urls.append(file_config.get("db", db))
if not db_urls:
- db_urls.append(file_config.get('db', 'default'))
+ db_urls.append(file_config.get("db", "default"))
config._current = None
for db_url in db_urls:
- if options.write_idents and provision.FOLLOWER_IDENT: # != 'master':
+ if options.write_idents and provision.FOLLOWER_IDENT: # != 'master':
with open(options.write_idents, "a") as file_:
file_.write(provision.FOLLOWER_IDENT + " " + db_url + "\n")
cfg = provision.setup_config(
- db_url, options, file_config, provision.FOLLOWER_IDENT)
+ db_url, options, file_config, provision.FOLLOWER_IDENT
+ )
if not config._current:
cfg.set_as_current(cfg, testing)
@@ -295,7 +371,7 @@ def _engine_uri(options, file_config):
@post
def _requirements(options, file_config):
- requirement_cls = file_config.get('sqla_testing', "requirement_cls")
+ requirement_cls = file_config.get("sqla_testing", "requirement_cls")
_setup_requirements(requirement_cls)
@@ -334,22 +410,28 @@ def _prep_testing_database(options, file_config):
pass
else:
for vname in view_names:
- e.execute(schema._DropView(
- schema.Table(vname, schema.MetaData())
- ))
+ e.execute(
+ schema._DropView(
+ schema.Table(vname, schema.MetaData())
+ )
+ )
if config.requirements.schemas.enabled_for_config(cfg):
try:
- view_names = inspector.get_view_names(
- schema="test_schema")
+ view_names = inspector.get_view_names(schema="test_schema")
except NotImplementedError:
pass
else:
for vname in view_names:
- e.execute(schema._DropView(
- schema.Table(vname, schema.MetaData(),
- schema="test_schema")
- ))
+ e.execute(
+ schema._DropView(
+ schema.Table(
+ vname,
+ schema.MetaData(),
+ schema="test_schema",
+ )
+ )
+ )
util.drop_all_tables(e, inspector)
@@ -358,23 +440,29 @@ def _prep_testing_database(options, file_config):
if against(cfg, "postgresql"):
from sqlalchemy.dialects import postgresql
+
for enum in inspector.get_enums("*"):
- e.execute(postgresql.DropEnumType(
- postgresql.ENUM(
- name=enum['name'],
- schema=enum['schema'])))
+ e.execute(
+ postgresql.DropEnumType(
+ postgresql.ENUM(
+ name=enum["name"], schema=enum["schema"]
+ )
+ )
+ )
@post
def _reverse_topological(options, file_config):
if options.reversetop:
from sqlalchemy.orm.util import randomize_unitofwork
+
randomize_unitofwork()
@post
def _post_setup_options(opt, file_config):
from sqlalchemy.testing import config
+
config.options = options
config.file_config = file_config
@@ -382,17 +470,20 @@ def _post_setup_options(opt, file_config):
@post
def _setup_profiling(options, file_config):
from sqlalchemy.testing import profiling
+
profiling._profile_stats = profiling.ProfileStatsFile(
- file_config.get('sqla_testing', 'profile_file'))
+ file_config.get("sqla_testing", "profile_file")
+ )
def want_class(cls):
if not issubclass(cls, fixtures.TestBase):
return False
- elif cls.__name__.startswith('_'):
+ elif cls.__name__.startswith("_"):
return False
- elif config.options.backend_only and not getattr(cls, '__backend__',
- False):
+ elif config.options.backend_only and not getattr(
+ cls, "__backend__", False
+ ):
return False
else:
return True
@@ -405,25 +496,28 @@ def want_method(cls, fn):
return False
elif include_tags:
return (
- hasattr(cls, '__tags__') and
- exclusions.tags(cls.__tags__).include_test(
- include_tags, exclude_tags)
+ hasattr(cls, "__tags__")
+ and exclusions.tags(cls.__tags__).include_test(
+ include_tags, exclude_tags
+ )
) or (
- hasattr(fn, '_sa_exclusion_extend') and
- fn._sa_exclusion_extend.include_test(
- include_tags, exclude_tags)
+ hasattr(fn, "_sa_exclusion_extend")
+ and fn._sa_exclusion_extend.include_test(
+ include_tags, exclude_tags
+ )
)
- elif exclude_tags and hasattr(cls, '__tags__'):
+ elif exclude_tags and hasattr(cls, "__tags__"):
return exclusions.tags(cls.__tags__).include_test(
- include_tags, exclude_tags)
- elif exclude_tags and hasattr(fn, '_sa_exclusion_extend'):
+ include_tags, exclude_tags
+ )
+ elif exclude_tags and hasattr(fn, "_sa_exclusion_extend"):
return fn._sa_exclusion_extend.include_test(include_tags, exclude_tags)
else:
return True
def generate_sub_tests(cls, module):
- if getattr(cls, '__backend__', False):
+ if getattr(cls, "__backend__", False):
for cfg in _possible_configs_for_cls(cls):
orig_name = cls.__name__
@@ -431,16 +525,13 @@ def generate_sub_tests(cls, module):
# pytest junit plugin, which is tripped up by the brackets
# and periods, so sanitize
- alpha_name = re.sub('[_\[\]\.]+', '_', cfg.name)
- alpha_name = re.sub('_+$', '', alpha_name)
+ alpha_name = re.sub("[_\[\]\.]+", "_", cfg.name)
+ alpha_name = re.sub("_+$", "", alpha_name)
name = "%s_%s" % (cls.__name__, alpha_name)
subcls = type(
name,
- (cls, ),
- {
- "_sa_orig_cls_name": orig_name,
- "__only_on_config__": cfg
- }
+ (cls,),
+ {"_sa_orig_cls_name": orig_name, "__only_on_config__": cfg},
)
setattr(module, name, subcls)
yield subcls
@@ -454,8 +545,8 @@ def start_test_class(cls):
def stop_test_class(cls):
- #from sqlalchemy import inspect
- #assert not inspect(testing.db).get_table_names()
+ # from sqlalchemy import inspect
+ # assert not inspect(testing.db).get_table_names()
engines.testing_reaper._stop_test_ctx()
try:
if not options.low_connections:
@@ -475,7 +566,7 @@ def final_process_cleanup():
def _setup_engine(cls):
- if getattr(cls, '__engine_options__', None):
+ if getattr(cls, "__engine_options__", None):
eng = engines.testing_engine(options=cls.__engine_options__)
config._current.push_engine(eng, testing)
@@ -485,7 +576,7 @@ def before_test(test, test_module_name, test_class, test_name):
# like a nose id, e.g.:
# "test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause"
- name = getattr(test_class, '_sa_orig_cls_name', test_class.__name__)
+ name = getattr(test_class, "_sa_orig_cls_name", test_class.__name__)
id_ = "%s.%s.%s" % (test_module_name, name, test_name)
@@ -505,16 +596,16 @@ def _possible_configs_for_cls(cls, reasons=None):
if spec(config_obj):
all_configs.remove(config_obj)
- if getattr(cls, '__only_on__', None):
+ if getattr(cls, "__only_on__", None):
spec = exclusions.db_spec(*util.to_list(cls.__only_on__))
for config_obj in list(all_configs):
if not spec(config_obj):
all_configs.remove(config_obj)
- if getattr(cls, '__only_on_config__', None):
+ if getattr(cls, "__only_on_config__", None):
all_configs.intersection_update([cls.__only_on_config__])
- if hasattr(cls, '__requires__'):
+ if hasattr(cls, "__requires__"):
requirements = config.requirements
for config_obj in list(all_configs):
for requirement in cls.__requires__:
@@ -527,7 +618,7 @@ def _possible_configs_for_cls(cls, reasons=None):
reasons.extend(skip_reasons)
break
- if hasattr(cls, '__prefer_requires__'):
+ if hasattr(cls, "__prefer_requires__"):
non_preferred = set()
requirements = config.requirements
for config_obj in list(all_configs):
@@ -546,30 +637,32 @@ def _do_skips(cls):
reasons = []
all_configs = _possible_configs_for_cls(cls, reasons)
- if getattr(cls, '__skip_if__', False):
- for c in getattr(cls, '__skip_if__'):
+ if getattr(cls, "__skip_if__", False):
+ for c in getattr(cls, "__skip_if__"):
if c():
- config.skip_test("'%s' skipped by %s" % (
- cls.__name__, c.__name__)
+ config.skip_test(
+ "'%s' skipped by %s" % (cls.__name__, c.__name__)
)
if not all_configs:
msg = "'%s' unsupported on any DB implementation %s%s" % (
cls.__name__,
", ".join(
- "'%s(%s)+%s'" % (
+ "'%s(%s)+%s'"
+ % (
config_obj.db.name,
".".join(
- str(dig) for dig in
- exclusions._server_version(config_obj.db)),
- config_obj.db.driver
+ str(dig)
+ for dig in exclusions._server_version(config_obj.db)
+ ),
+ config_obj.db.driver,
)
- for config_obj in config.Config.all_configs()
+ for config_obj in config.Config.all_configs()
),
- ", ".join(reasons)
+ ", ".join(reasons),
)
config.skip_test(msg)
- elif hasattr(cls, '__prefer_backends__'):
+ elif hasattr(cls, "__prefer_backends__"):
non_preferred = set()
spec = exclusions.db_spec(*util.to_list(cls.__prefer_backends__))
for config_obj in all_configs:
diff --git a/lib/sqlalchemy/testing/plugin/pytestplugin.py b/lib/sqlalchemy/testing/plugin/pytestplugin.py
index da682ea00..fd0a48462 100644
--- a/lib/sqlalchemy/testing/plugin/pytestplugin.py
+++ b/lib/sqlalchemy/testing/plugin/pytestplugin.py
@@ -13,6 +13,7 @@ import os
try:
import xdist # noqa
+
has_xdist = True
except ImportError:
has_xdist = False
@@ -24,30 +25,42 @@ def pytest_addoption(parser):
def make_option(name, **kw):
callback_ = kw.pop("callback", None)
if callback_:
+
class CallableAction(argparse.Action):
- def __call__(self, parser, namespace,
- values, option_string=None):
+ def __call__(
+ self, parser, namespace, values, option_string=None
+ ):
callback_(option_string, values, parser)
+
kw["action"] = CallableAction
zeroarg_callback = kw.pop("zeroarg_callback", None)
if zeroarg_callback:
+
class CallableAction(argparse.Action):
- def __init__(self, option_strings,
- dest, default=False,
- required=False, help=None):
- super(CallableAction, self).__init__(
- option_strings=option_strings,
- dest=dest,
- nargs=0,
- const=True,
- default=default,
- required=required,
- help=help)
-
- def __call__(self, parser, namespace,
- values, option_string=None):
+ def __init__(
+ self,
+ option_strings,
+ dest,
+ default=False,
+ required=False,
+ help=None,
+ ):
+ super(CallableAction, self).__init__(
+ option_strings=option_strings,
+ dest=dest,
+ nargs=0,
+ const=True,
+ default=default,
+ required=required,
+ help=help,
+ )
+
+ def __call__(
+ self, parser, namespace, values, option_string=None
+ ):
zeroarg_callback(option_string, values, parser)
+
kw["action"] = CallableAction
group.addoption(name, **kw)
@@ -59,18 +72,18 @@ def pytest_addoption(parser):
def pytest_configure(config):
if hasattr(config, "slaveinput"):
plugin_base.restore_important_follower_config(config.slaveinput)
- plugin_base.configure_follower(
- config.slaveinput["follower_ident"]
- )
+ plugin_base.configure_follower(config.slaveinput["follower_ident"])
else:
- if config.option.write_idents and \
- os.path.exists(config.option.write_idents):
+ if config.option.write_idents and os.path.exists(
+ config.option.write_idents
+ ):
os.remove(config.option.write_idents)
plugin_base.pre_begin(config.option)
- plugin_base.set_coverage_flag(bool(getattr(config.option,
- "cov_source", False)))
+ plugin_base.set_coverage_flag(
+ bool(getattr(config.option, "cov_source", False))
+ )
plugin_base.set_skip_test(pytest.skip.Exception)
@@ -94,10 +107,12 @@ if has_xdist:
node.slaveinput["follower_ident"] = "test_%s" % uuid.uuid4().hex[0:12]
from sqlalchemy.testing import provision
+
provision.create_follower_db(node.slaveinput["follower_ident"])
def pytest_testnodedown(node, error):
from sqlalchemy.testing import provision
+
provision.drop_follower_db(node.slaveinput["follower_ident"])
@@ -114,19 +129,22 @@ def pytest_collection_modifyitems(session, config, items):
rebuilt_items = collections.defaultdict(list)
items[:] = [
- item for item in
- items if isinstance(item.parent, pytest.Instance)
- and not item.parent.parent.name.startswith("_")]
+ item
+ for item in items
+ if isinstance(item.parent, pytest.Instance)
+ and not item.parent.parent.name.startswith("_")
+ ]
test_classes = set(item.parent for item in items)
for test_class in test_classes:
for sub_cls in plugin_base.generate_sub_tests(
- test_class.cls, test_class.parent.module):
+ test_class.cls, test_class.parent.module
+ ):
if sub_cls is not test_class.cls:
list_ = rebuilt_items[test_class.cls]
for inst in pytest.Class(
- sub_cls.__name__,
- parent=test_class.parent.parent).collect():
+ sub_cls.__name__, parent=test_class.parent.parent
+ ).collect():
list_.extend(inst.collect())
newitems = []
@@ -139,23 +157,29 @@ def pytest_collection_modifyitems(session, config, items):
# seems like the functions attached to a test class aren't sorted already?
# is that true and why's that? (when using unittest, they're sorted)
- items[:] = sorted(newitems, key=lambda item: (
- item.parent.parent.parent.name,
- item.parent.parent.name,
- item.name
- ))
+ items[:] = sorted(
+ newitems,
+ key=lambda item: (
+ item.parent.parent.parent.name,
+ item.parent.parent.name,
+ item.name,
+ ),
+ )
def pytest_pycollect_makeitem(collector, name, obj):
if inspect.isclass(obj) and plugin_base.want_class(obj):
return pytest.Class(name, parent=collector)
- elif inspect.isfunction(obj) and \
- isinstance(collector, pytest.Instance) and \
- plugin_base.want_method(collector.cls, obj):
+ elif (
+ inspect.isfunction(obj)
+ and isinstance(collector, pytest.Instance)
+ and plugin_base.want_method(collector.cls, obj)
+ ):
return pytest.Function(name, parent=collector)
else:
return []
+
_current_class = None
@@ -180,6 +204,7 @@ def pytest_runtest_setup(item):
global _current_class
class_teardown(item.parent.parent)
_current_class = None
+
item.parent.parent.addfinalizer(finalize)
test_setup(item)
@@ -194,8 +219,9 @@ def pytest_runtest_teardown(item):
def test_setup(item):
- plugin_base.before_test(item, item.parent.module.__name__,
- item.parent.cls, item.name)
+ plugin_base.before_test(
+ item, item.parent.module.__name__, item.parent.cls, item.name
+ )
def test_teardown(item):
diff --git a/lib/sqlalchemy/testing/profiling.py b/lib/sqlalchemy/testing/profiling.py
index fab99b186..3986985c7 100644
--- a/lib/sqlalchemy/testing/profiling.py
+++ b/lib/sqlalchemy/testing/profiling.py
@@ -42,17 +42,16 @@ class ProfileStatsFile(object):
def __init__(self, filename):
self.force_write = (
- config.options is not None and
- config.options.force_write_profiles
+ config.options is not None and config.options.force_write_profiles
)
self.write = self.force_write or (
- config.options is not None and
- config.options.write_profiles
+ config.options is not None and config.options.write_profiles
)
self.fname = os.path.abspath(filename)
self.short_fname = os.path.split(self.fname)[-1]
self.data = collections.defaultdict(
- lambda: collections.defaultdict(dict))
+ lambda: collections.defaultdict(dict)
+ )
self._read()
if self.write:
# rewrite for the case where features changed,
@@ -65,7 +64,7 @@ class ProfileStatsFile(object):
dbapi_key = config.db.name + "_" + config.db.driver
# keep it at 2.7, 3.1, 3.2, etc. for now.
- py_version = '.'.join([str(v) for v in sys.version_info[0:2]])
+ py_version = ".".join([str(v) for v in sys.version_info[0:2]])
platform_tokens = [py_version]
platform_tokens.append(dbapi_key)
@@ -87,8 +86,7 @@ class ProfileStatsFile(object):
def has_stats(self):
test_key = _current_test
return (
- test_key in self.data and
- self.platform_key in self.data[test_key]
+ test_key in self.data and self.platform_key in self.data[test_key]
)
def result(self, callcount):
@@ -96,15 +94,15 @@ class ProfileStatsFile(object):
per_fn = self.data[test_key]
per_platform = per_fn[self.platform_key]
- if 'counts' not in per_platform:
- per_platform['counts'] = counts = []
+ if "counts" not in per_platform:
+ per_platform["counts"] = counts = []
else:
- counts = per_platform['counts']
+ counts = per_platform["counts"]
- if 'current_count' not in per_platform:
- per_platform['current_count'] = current_count = 0
+ if "current_count" not in per_platform:
+ per_platform["current_count"] = current_count = 0
else:
- current_count = per_platform['current_count']
+ current_count = per_platform["current_count"]
has_count = len(counts) > current_count
@@ -114,16 +112,16 @@ class ProfileStatsFile(object):
self._write()
result = None
else:
- result = per_platform['lineno'], counts[current_count]
- per_platform['current_count'] += 1
+ result = per_platform["lineno"], counts[current_count]
+ per_platform["current_count"] += 1
return result
def replace(self, callcount):
test_key = _current_test
per_fn = self.data[test_key]
per_platform = per_fn[self.platform_key]
- counts = per_platform['counts']
- current_count = per_platform['current_count']
+ counts = per_platform["counts"]
+ current_count = per_platform["current_count"]
if current_count < len(counts):
counts[current_count - 1] = callcount
else:
@@ -164,9 +162,9 @@ class ProfileStatsFile(object):
per_fn = self.data[test_key]
per_platform = per_fn[platform_key]
c = [int(count) for count in counts.split(",")]
- per_platform['counts'] = c
- per_platform['lineno'] = lineno + 1
- per_platform['current_count'] = 0
+ per_platform["counts"] = c
+ per_platform["lineno"] = lineno + 1
+ per_platform["current_count"] = 0
profile_f.close()
def _write(self):
@@ -179,7 +177,7 @@ class ProfileStatsFile(object):
profile_f.write("\n# TEST: %s\n\n" % test_key)
for platform_key in sorted(per_fn):
per_platform = per_fn[platform_key]
- c = ",".join(str(count) for count in per_platform['counts'])
+ c = ",".join(str(count) for count in per_platform["counts"])
profile_f.write("%s %s %s\n" % (test_key, platform_key, c))
profile_f.close()
@@ -199,7 +197,9 @@ def function_call_count(variance=0.05):
def wrap(*args, **kw):
with count_functions(variance=variance):
return fn(*args, **kw)
+
return update_wrapper(wrap, fn)
+
return decorate
@@ -213,21 +213,22 @@ def count_functions(variance=0.05):
"No profiling stats available on this "
"platform for this function. Run tests with "
"--write-profiles to add statistics to %s for "
- "this platform." % _profile_stats.short_fname)
+ "this platform." % _profile_stats.short_fname
+ )
gc_collect()
pr = cProfile.Profile()
pr.enable()
- #began = time.time()
+ # began = time.time()
yield
- #ended = time.time()
+ # ended = time.time()
pr.disable()
- #s = compat.StringIO()
+ # s = compat.StringIO()
stats = pstats.Stats(pr, stream=sys.stdout)
- #timespent = ended - began
+ # timespent = ended - began
callcount = stats.total_calls
expected = _profile_stats.result(callcount)
@@ -237,11 +238,7 @@ def count_functions(variance=0.05):
else:
line_no, expected_count = expected
- print(("Pstats calls: %d Expected %s" % (
- callcount,
- expected_count
- )
- ))
+ print(("Pstats calls: %d Expected %s" % (callcount, expected_count)))
stats.sort_stats("cumulative")
stats.print_stats()
@@ -259,7 +256,9 @@ def count_functions(variance=0.05):
"--write-profiles to "
"regenerate this callcount."
% (
- callcount, (variance * 100),
- expected_count, _profile_stats.platform_key))
-
-
+ callcount,
+ (variance * 100),
+ expected_count,
+ _profile_stats.platform_key,
+ )
+ )
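
The count_functions fixture above profiles a block with cProfile and compares pstats' total_calls figure to a stored expectation within a variance band. A minimal standalone sketch of that pattern, with the recording mode reduced to a print (SQLAlchemy's real expected counts live in the ProfileStatsFile, not inline like this):

import contextlib
import cProfile
import pstats
import sys

@contextlib.contextmanager
def count_functions(expected=None, variance=0.05):
    # profile everything executed inside the with-block
    pr = cProfile.Profile()
    pr.enable()
    yield
    pr.disable()
    callcount = pstats.Stats(pr, stream=sys.stdout).total_calls
    if expected is None:
        # stand-in for --write-profiles: report so the count can be stored
        print("observed call count: %d" % callcount)
    elif abs(callcount - expected) > expected * variance:
        raise AssertionError(
            "function call count %d not within %d%% of expected %d"
            % (callcount, variance * 100, expected)
        )

# first run prints the observed count; pass it back in to assert on it
with count_functions():
    sorted(range(1000), key=str)
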
diff --git a/lib/sqlalchemy/testing/provision.py b/lib/sqlalchemy/testing/provision.py
index c0ca7c1cb..25028ccb3 100644
--- a/lib/sqlalchemy/testing/provision.py
+++ b/lib/sqlalchemy/testing/provision.py
@@ -8,6 +8,7 @@ import collections
import os
import time
import logging
+
log = logging.getLogger(__name__)
FOLLOWER_IDENT = None
@@ -25,6 +26,7 @@ class register(object):
def decorate(fn):
self.fns[dbname] = fn
return self
+
return decorate
def __call__(self, cfg, *arg):
@@ -38,7 +40,7 @@ class register(object):
if backend in self.fns:
return self.fns[backend](cfg, *arg)
else:
- return self.fns['*'](cfg, *arg)
+ return self.fns["*"](cfg, *arg)
def create_follower_db(follower_ident):
@@ -82,9 +84,7 @@ def _configs_for_db_operation():
for cfg in config.Config.all_configs():
url = cfg.db.url
backend = url.get_backend_name()
- host_conf = (
- backend,
- url.username, url.host, url.database)
+ host_conf = (backend, url.username, url.host, url.database)
if host_conf not in hosts:
yield cfg
@@ -128,14 +128,13 @@ def _follower_url_from_main(url, ident):
@_update_db_opts.for_db("mssql")
def _mssql_update_db_opts(db_url, db_opts):
- db_opts['legacy_schema_aliasing'] = False
-
+ db_opts["legacy_schema_aliasing"] = False
@_follower_url_from_main.for_db("sqlite")
def _sqlite_follower_url_from_main(url, ident):
url = sa_url.make_url(url)
- if not url.database or url.database == ':memory:':
+ if not url.database or url.database == ":memory:":
return url
else:
return sa_url.make_url("sqlite:///%s.db" % ident)
@@ -151,19 +150,20 @@ def _sqlite_post_configure_engine(url, engine, follower_ident):
# as an attached
if not follower_ident:
dbapi_connection.execute(
- 'ATTACH DATABASE "test_schema.db" AS test_schema')
+ 'ATTACH DATABASE "test_schema.db" AS test_schema'
+ )
else:
dbapi_connection.execute(
'ATTACH DATABASE "%s_test_schema.db" AS test_schema'
- % follower_ident)
+ % follower_ident
+ )
@_create_db.for_db("postgresql")
def _pg_create_db(cfg, eng, ident):
template_db = cfg.options.postgresql_templatedb
- with eng.connect().execution_options(
- isolation_level="AUTOCOMMIT") as conn:
+ with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn:
try:
_pg_drop_db(cfg, conn, ident)
except Exception:
@@ -175,7 +175,8 @@ def _pg_create_db(cfg, eng, ident):
while True:
try:
conn.execute(
- "CREATE DATABASE %s TEMPLATE %s" % (ident, template_db))
+ "CREATE DATABASE %s TEMPLATE %s" % (ident, template_db)
+ )
except exc.OperationalError as err:
attempt += 1
if attempt >= 3:
@@ -184,8 +185,11 @@ def _pg_create_db(cfg, eng, ident):
log.info(
"Waiting to create %s, URI %r, "
"template DB %s is in use sleeping for .5",
- ident, eng.url, template_db)
- time.sleep(.5)
+ ident,
+ eng.url,
+ template_db,
+ )
+ time.sleep(0.5)
else:
break
@@ -203,9 +207,11 @@ def _mysql_create_db(cfg, eng, ident):
# 1271, u"Illegal mix of collations for operation 'UNION'"
conn.execute("CREATE DATABASE %s CHARACTER SET utf8mb3" % ident)
conn.execute(
- "CREATE DATABASE %s_test_schema CHARACTER SET utf8mb3" % ident)
+ "CREATE DATABASE %s_test_schema CHARACTER SET utf8mb3" % ident
+ )
conn.execute(
- "CREATE DATABASE %s_test_schema_2 CHARACTER SET utf8mb3" % ident)
+ "CREATE DATABASE %s_test_schema_2 CHARACTER SET utf8mb3" % ident
+ )
@_configure_follower.for_db("mysql")
@@ -221,14 +227,15 @@ def _sqlite_create_db(cfg, eng, ident):
@_drop_db.for_db("postgresql")
def _pg_drop_db(cfg, eng, ident):
- with eng.connect().execution_options(
- isolation_level="AUTOCOMMIT") as conn:
+ with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn:
conn.execute(
text(
"select pg_terminate_backend(pid) from pg_stat_activity "
"where usename=current_user and pid != pg_backend_pid() "
"and datname=:dname"
- ), dname=ident)
+ ),
+ dname=ident,
+ )
conn.execute("DROP DATABASE %s" % ident)
@@ -257,11 +264,12 @@ def _oracle_create_db(cfg, eng, ident):
conn.execute("create user %s identified by xe" % ident)
conn.execute("create user %s_ts1 identified by xe" % ident)
conn.execute("create user %s_ts2 identified by xe" % ident)
- conn.execute("grant dba to %s" % (ident, ))
+ conn.execute("grant dba to %s" % (ident,))
conn.execute("grant unlimited tablespace to %s" % ident)
conn.execute("grant unlimited tablespace to %s_ts1" % ident)
conn.execute("grant unlimited tablespace to %s_ts2" % ident)
+
@_configure_follower.for_db("oracle")
def _oracle_configure_follower(config, ident):
config.test_schema = "%s_ts1" % ident
@@ -320,6 +328,7 @@ def reap_dbs(idents_file):
elif backend == "mssql":
_reap_mssql_dbs(url, ident)
+
def _reap_oracle_dbs(url, idents):
log.info("db reaper connecting to %r", url)
eng = create_engine(url)
@@ -330,8 +339,9 @@ def _reap_oracle_dbs(url, idents):
to_reap = conn.execute(
"select u.username from all_users u where username "
"like 'TEST_%' and not exists (select username "
- "from v$session where username=u.username)")
- all_names = {username.lower() for (username, ) in to_reap}
+ "from v$session where username=u.username)"
+ )
+ all_names = {username.lower() for (username,) in to_reap}
to_drop = set()
for name in all_names:
if name.endswith("_ts1") or name.endswith("_ts2"):
@@ -348,28 +358,28 @@ def _reap_oracle_dbs(url, idents):
if _ora_drop_ignore(conn, username):
dropped += 1
log.info(
- "Dropped %d out of %d stale databases detected",
- dropped, total)
-
+ "Dropped %d out of %d stale databases detected", dropped, total
+ )
@_follower_url_from_main.for_db("oracle")
def _oracle_follower_url_from_main(url, ident):
url = sa_url.make_url(url)
url.username = ident
- url.password = 'xe'
+ url.password = "xe"
return url
@_create_db.for_db("mssql")
def _mssql_create_db(cfg, eng, ident):
- with eng.connect().execution_options(
- isolation_level="AUTOCOMMIT") as conn:
+ with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn:
conn.execute("create database %s" % ident)
conn.execute(
- "ALTER DATABASE %s SET ALLOW_SNAPSHOT_ISOLATION ON" % ident)
+ "ALTER DATABASE %s SET ALLOW_SNAPSHOT_ISOLATION ON" % ident
+ )
conn.execute(
- "ALTER DATABASE %s SET READ_COMMITTED_SNAPSHOT ON" % ident)
+ "ALTER DATABASE %s SET READ_COMMITTED_SNAPSHOT ON" % ident
+ )
conn.execute("use %s" % ident)
conn.execute("create schema test_schema")
conn.execute("create schema test_schema_2")
@@ -377,10 +387,10 @@ def _mssql_create_db(cfg, eng, ident):
@_drop_db.for_db("mssql")
def _mssql_drop_db(cfg, eng, ident):
- with eng.connect().execution_options(
- isolation_level="AUTOCOMMIT") as conn:
+ with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn:
_mssql_drop_ignore(conn, ident)
+
def _mssql_drop_ignore(conn, ident):
try:
# typically when this happens, we can't KILL the session anyway,
@@ -401,8 +411,7 @@ def _mssql_drop_ignore(conn, ident):
def _reap_mssql_dbs(url, idents):
log.info("db reaper connecting to %r", url)
eng = create_engine(url)
- with eng.connect().execution_options(
- isolation_level="AUTOCOMMIT") as conn:
+ with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn:
log.info("identifiers in file: %s", ", ".join(idents))
@@ -410,8 +419,9 @@ def _reap_mssql_dbs(url, idents):
"select d.name from sys.databases as d where name "
"like 'TEST_%' and not exists (select session_id "
"from sys.dm_exec_sessions "
- "where database_id=d.database_id)")
- all_names = {dbname.lower() for (dbname, ) in to_reap}
+ "where database_id=d.database_id)"
+ )
+ all_names = {dbname.lower() for (dbname,) in to_reap}
to_drop = set()
for name in all_names:
if name in idents:
@@ -422,5 +432,5 @@ def _reap_mssql_dbs(url, idents):
if _mssql_drop_ignore(conn, dbname):
dropped += 1
log.info(
- "Dropped %d out of %d stale databases detected",
- dropped, total)
+ "Dropped %d out of %d stale databases detected", dropped, total
+ )
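
All of the hooks above (_create_db, _drop_db, _update_db_opts, ...) hang off the register class at the top of provision.py: implementations sign up per backend name, with "*" as the catch-all. A condensed sketch of that dispatch, taking a plain backend string where the real code derives it from the cfg's engine URL (the backend names and template name below are illustrative):

class register(object):
    def __init__(self):
        self.fns = {}

    def for_db(self, dbname):
        def decorate(fn):
            self.fns[dbname] = fn
            # returning the registry (not fn) lets decorators stack,
            # as in the original
            return self
        return decorate

    def __call__(self, backend, *arg):
        # route to the backend-specific function, else the "*" fallback
        if backend in self.fns:
            return self.fns[backend](*arg)
        return self.fns["*"](*arg)

_create_db = register()

@_create_db.for_db("*")
def _generic_create_db(ident):
    return "no-op create for %s" % ident

@_create_db.for_db("postgresql")
def _pg_create_db(ident):
    return "CREATE DATABASE %s TEMPLATE template1" % ident

print(_create_db("postgresql", "test_xyz"))  # backend-specific path
print(_create_db("sqlite", "test_xyz"))      # falls through to "*"
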
diff --git a/lib/sqlalchemy/testing/replay_fixture.py b/lib/sqlalchemy/testing/replay_fixture.py
index b50f52e3d..9832b07a2 100644
--- a/lib/sqlalchemy/testing/replay_fixture.py
+++ b/lib/sqlalchemy/testing/replay_fixture.py
@@ -11,7 +11,6 @@ from sqlalchemy.orm import Session
class ReplayFixtureTest(fixtures.TestBase):
-
@contextlib.contextmanager
def _dummy_ctx(self, *arg, **kw):
yield
@@ -22,8 +21,8 @@ class ReplayFixtureTest(fixtures.TestBase):
creator = config.db.pool._creator
recorder = lambda: dbapi_session.recorder(creator())
engine = create_engine(
- config.db.url, creator=recorder,
- use_native_hstore=False)
+ config.db.url, creator=recorder, use_native_hstore=False
+ )
self.metadata = MetaData(engine)
self.engine = engine
self.session = Session(engine)
@@ -37,8 +36,8 @@ class ReplayFixtureTest(fixtures.TestBase):
player = lambda: dbapi_session.player()
engine = create_engine(
- config.db.url, creator=player,
- use_native_hstore=False)
+ config.db.url, creator=player, use_native_hstore=False
+ )
self.metadata = MetaData(engine)
self.engine = engine
@@ -74,21 +73,49 @@ class ReplayableSession(object):
NoAttribute = object()
if util.py2k:
- Natives = set([getattr(types, t)
- for t in dir(types) if not t.startswith('_')]).\
- difference([getattr(types, t)
- for t in ('FunctionType', 'BuiltinFunctionType',
- 'MethodType', 'BuiltinMethodType',
- 'LambdaType', 'UnboundMethodType',)])
+ Natives = set(
+ [getattr(types, t) for t in dir(types) if not t.startswith("_")]
+ ).difference(
+ [
+ getattr(types, t)
+ for t in (
+ "FunctionType",
+ "BuiltinFunctionType",
+ "MethodType",
+ "BuiltinMethodType",
+ "LambdaType",
+ "UnboundMethodType",
+ )
+ ]
+ )
else:
- Natives = set([getattr(types, t)
- for t in dir(types) if not t.startswith('_')]).\
- union([type(t) if not isinstance(t, type)
- else t for t in __builtins__.values()]).\
- difference([getattr(types, t)
- for t in ('FunctionType', 'BuiltinFunctionType',
- 'MethodType', 'BuiltinMethodType',
- 'LambdaType', )])
+ Natives = (
+ set(
+ [
+ getattr(types, t)
+ for t in dir(types)
+ if not t.startswith("_")
+ ]
+ )
+ .union(
+ [
+ type(t) if not isinstance(t, type) else t
+ for t in __builtins__.values()
+ ]
+ )
+ .difference(
+ [
+ getattr(types, t)
+ for t in (
+ "FunctionType",
+ "BuiltinFunctionType",
+ "MethodType",
+ "BuiltinMethodType",
+ "LambdaType",
+ )
+ ]
+ )
+ )
def __init__(self):
self.buffer = deque()
@@ -105,8 +132,10 @@ class ReplayableSession(object):
self._subject = subject
def __call__(self, *args, **kw):
- subject, buffer = [object.__getattribute__(self, x)
- for x in ('_subject', '_buffer')]
+ subject, buffer = [
+ object.__getattribute__(self, x)
+ for x in ("_subject", "_buffer")
+ ]
result = subject(*args, **kw)
if type(result) not in ReplayableSession.Natives:
@@ -126,8 +155,10 @@ class ReplayableSession(object):
except AttributeError:
pass
- subject, buffer = [object.__getattribute__(self, x)
- for x in ('_subject', '_buffer')]
+ subject, buffer = [
+ object.__getattribute__(self, x)
+ for x in ("_subject", "_buffer")
+ ]
try:
result = type(subject).__getattribute__(subject, key)
except AttributeError:
@@ -146,7 +177,7 @@ class ReplayableSession(object):
self._buffer = buffer
def __call__(self, *args, **kw):
- buffer = object.__getattribute__(self, '_buffer')
+ buffer = object.__getattribute__(self, "_buffer")
result = buffer.popleft()
if result is ReplayableSession.Callable:
return self
@@ -162,7 +193,7 @@ class ReplayableSession(object):
return object.__getattribute__(self, key)
except AttributeError:
pass
- buffer = object.__getattribute__(self, '_buffer')
+ buffer = object.__getattribute__(self, "_buffer")
result = buffer.popleft()
if result is ReplayableSession.Callable:
return self
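
ReplayableSession's two halves share one deque: the Recorder proxies a live object and appends each result to the buffer, while the Player pops the identical sequence back with no live object behind it. A toy reduction of that idea (the Natives type filtering and the Callable/NoAttribute sentinels that make nested proxying work are omitted):

from collections import deque

class Recorder(object):
    def __init__(self, subject, buffer):
        self._subject = subject
        self._buffer = buffer

    def __getattr__(self, key):
        # fetch from the real object and remember the answer
        result = getattr(self._subject, key)
        self._buffer.append(result)
        return result

class Player(object):
    def __init__(self, buffer):
        self._buffer = buffer

    def __getattr__(self, key):
        # no live object: hand back the recorded sequence in order
        return self._buffer.popleft()

buffer = deque()
recorder = Recorder(3 + 4j, buffer)
recorder.real, recorder.imag       # recorded: 3.0, 4.0
player = Player(buffer)
print(player.real, player.imag)    # replayed: 3.0 4.0
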
diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py
index 58df643f4..c96d26d32 100644
--- a/lib/sqlalchemy/testing/requirements.py
+++ b/lib/sqlalchemy/testing/requirements.py
@@ -26,7 +26,6 @@ class Requirements(object):
class SuiteRequirements(Requirements):
-
@property
def create_table(self):
"""target platform can emit basic CreateTable DDL."""
@@ -68,8 +67,8 @@ class SuiteRequirements(Requirements):
# somehow only_if([x, y]) isn't working here, negation/conjunctions
# getting confused.
return exclusions.only_if(
- lambda: self.on_update_cascade.enabled or
- self.deferrable_fks.enabled
+ lambda: self.on_update_cascade.enabled
+ or self.deferrable_fks.enabled
)
@property
@@ -231,22 +230,21 @@ class SuiteRequirements(Requirements):
def sane_rowcount(self):
return exclusions.skip_if(
lambda config: not config.db.dialect.supports_sane_rowcount,
- "driver doesn't support 'sane' rowcount"
+ "driver doesn't support 'sane' rowcount",
)
@property
def sane_multi_rowcount(self):
return exclusions.fails_if(
lambda config: not config.db.dialect.supports_sane_multi_rowcount,
- "driver %(driver)s %(doesnt_support)s 'sane' multi row count"
+ "driver %(driver)s %(doesnt_support)s 'sane' multi row count",
)
@property
def sane_rowcount_w_returning(self):
return exclusions.fails_if(
- lambda config:
- not config.db.dialect.supports_sane_rowcount_returning,
- "driver doesn't support 'sane' rowcount when returning is on"
+ lambda config: not config.db.dialect.supports_sane_rowcount_returning,
+ "driver doesn't support 'sane' rowcount when returning is on",
)
@property
@@ -255,9 +253,9 @@ class SuiteRequirements(Requirements):
INSERT DEFAULT VALUES or equivalent."""
return exclusions.only_if(
- lambda config: config.db.dialect.supports_empty_insert or
- config.db.dialect.supports_default_values,
- "empty inserts not supported"
+ lambda config: config.db.dialect.supports_empty_insert
+ or config.db.dialect.supports_default_values,
+ "empty inserts not supported",
)
@property
@@ -272,7 +270,7 @@ class SuiteRequirements(Requirements):
return exclusions.only_if(
lambda config: config.db.dialect.implicit_returning,
- "%(database)s %(does_support)s 'returning'"
+ "%(database)s %(does_support)s 'returning'",
)
@property
@@ -297,7 +295,7 @@ class SuiteRequirements(Requirements):
return exclusions.skip_if(
lambda config: not config.db.dialect.requires_name_normalize,
- "Backend does not require denormalized names."
+ "Backend does not require denormalized names.",
)
@property
@@ -307,7 +305,7 @@ class SuiteRequirements(Requirements):
return exclusions.skip_if(
lambda config: not config.db.dialect.supports_multivalues_insert,
- "Backend does not support multirow inserts."
+ "Backend does not support multirow inserts.",
)
@property
@@ -355,27 +353,32 @@ class SuiteRequirements(Requirements):
def server_side_cursors(self):
"""Target dialect must support server side cursors."""
- return exclusions.only_if([
- lambda config: config.db.dialect.supports_server_side_cursors
- ], "no server side cursors support")
+ return exclusions.only_if(
+ [lambda config: config.db.dialect.supports_server_side_cursors],
+ "no server side cursors support",
+ )
@property
def sequences(self):
"""Target database must support SEQUENCEs."""
- return exclusions.only_if([
- lambda config: config.db.dialect.supports_sequences
- ], "no sequence support")
+ return exclusions.only_if(
+ [lambda config: config.db.dialect.supports_sequences],
+ "no sequence support",
+ )
@property
def sequences_optional(self):
"""Target database supports sequences, but also optionally
as a means of generating new PK values."""
- return exclusions.only_if([
- lambda config: config.db.dialect.supports_sequences and
- config.db.dialect.sequences_optional
- ], "no sequence support, or sequences not optional")
+ return exclusions.only_if(
+ [
+ lambda config: config.db.dialect.supports_sequences
+ and config.db.dialect.sequences_optional
+ ],
+ "no sequence support, or sequences not optional",
+ )
@property
def reflects_pk_names(self):
@@ -841,7 +844,8 @@ class SuiteRequirements(Requirements):
"""
return exclusions.skip_if(
- lambda config: config.options.low_connections)
+ lambda config: config.options.low_connections
+ )
@property
def timing_intensive(self):
@@ -859,37 +863,37 @@ class SuiteRequirements(Requirements):
"""
return exclusions.skip_if(
lambda config: util.py3k and config.options.has_coverage,
- "Stability issues with coverage + py3k"
+ "Stability issues with coverage + py3k",
)
@property
def python2(self):
return exclusions.skip_if(
lambda: sys.version_info >= (3,),
- "Python version 2.xx is required."
+ "Python version 2.xx is required.",
)
@property
def python3(self):
return exclusions.skip_if(
- lambda: sys.version_info < (3,),
- "Python version 3.xx is required."
+ lambda: sys.version_info < (3,), "Python version 3.xx is required."
)
@property
def cpython(self):
return exclusions.only_if(
- lambda: util.cpython,
- "cPython interpreter needed"
+ lambda: util.cpython, "cPython interpreter needed"
)
@property
def non_broken_pickle(self):
from sqlalchemy.util import pickle
+
return exclusions.only_if(
- lambda: not util.pypy and pickle.__name__ == 'cPickle'
- or sys.version_info >= (3, 2),
- "Needs cPickle+cPython or newer Python 3 pickle"
+ lambda: not util.pypy
+ and pickle.__name__ == "cPickle"
+ or sys.version_info >= (3, 2),
+ "Needs cPickle+cPython or newer Python 3 pickle",
)
@property
@@ -910,7 +914,7 @@ class SuiteRequirements(Requirements):
"""
return exclusions.skip_if(
lambda config: config.options.has_coverage,
- "Issues observed when coverage is enabled"
+ "Issues observed when coverage is enabled",
)
def _has_mysql_on_windows(self, config):
@@ -931,8 +935,9 @@ class SuiteRequirements(Requirements):
def _has_sqlite(self):
from sqlalchemy import create_engine
+
try:
- create_engine('sqlite://')
+ create_engine("sqlite://")
return True
except ImportError:
return False
@@ -940,6 +945,7 @@ class SuiteRequirements(Requirements):
def _has_cextensions(self):
try:
from sqlalchemy import cresultproxy, cprocessors
+
return True
except ImportError:
return False
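
Every property above boils down to a predicate plus a reason string handed to exclusions.skip_if or only_if, evaluated later against the active config. A toy skip_if with a stand-in config object (the real exclusions module additionally composes predicate lists, db-spec strings, and zero-argument lambdas):

class skip_if(object):
    def __init__(self, predicate, reason=None):
        self.predicate = predicate
        self.reason = reason

    def enabled_for_config(self, config):
        # the requirement is met when the skip predicate does not fire
        return not self.predicate(config)

class FakeConfig(object):
    class db(object):
        class dialect(object):
            supports_sane_rowcount = False

sane_rowcount = skip_if(
    lambda config: not config.db.dialect.supports_sane_rowcount,
    "driver doesn't support 'sane' rowcount",
)
# False here: this config would be skipped, with the reason reported
print(sane_rowcount.enabled_for_config(FakeConfig))
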
diff --git a/lib/sqlalchemy/testing/runner.py b/lib/sqlalchemy/testing/runner.py
index 87be0749c..6aa820fd5 100644
--- a/lib/sqlalchemy/testing/runner.py
+++ b/lib/sqlalchemy/testing/runner.py
@@ -47,4 +47,4 @@ def setup_py_test():
to nose.
"""
- nose.main(addplugins=[NoseSQLAlchemy()], argv=['runner'])
+ nose.main(addplugins=[NoseSQLAlchemy()], argv=["runner"])
diff --git a/lib/sqlalchemy/testing/schema.py b/lib/sqlalchemy/testing/schema.py
index 401c8cbb7..b345a9487 100644
--- a/lib/sqlalchemy/testing/schema.py
+++ b/lib/sqlalchemy/testing/schema.py
@@ -9,7 +9,7 @@ from . import exclusions
from .. import schema, event
from . import config
-__all__ = 'Table', 'Column',
+__all__ = "Table", "Column"
table_options = {}
@@ -17,30 +17,35 @@ table_options = {}
def Table(*args, **kw):
"""A schema.Table wrapper/hook for dialect-specific tweaks."""
- test_opts = {k: kw.pop(k) for k in list(kw) if k.startswith('test_')}
+ test_opts = {k: kw.pop(k) for k in list(kw) if k.startswith("test_")}
kw.update(table_options)
- if exclusions.against(config._current, 'mysql'):
- if 'mysql_engine' not in kw and 'mysql_type' not in kw and \
- "autoload_with" not in kw:
- if 'test_needs_fk' in test_opts or 'test_needs_acid' in test_opts:
- kw['mysql_engine'] = 'InnoDB'
+ if exclusions.against(config._current, "mysql"):
+ if (
+ "mysql_engine" not in kw
+ and "mysql_type" not in kw
+ and "autoload_with" not in kw
+ ):
+ if "test_needs_fk" in test_opts or "test_needs_acid" in test_opts:
+ kw["mysql_engine"] = "InnoDB"
else:
- kw['mysql_engine'] = 'MyISAM'
+ kw["mysql_engine"] = "MyISAM"
# Apply some default cascading rules for self-referential foreign keys.
    # MySQL InnoDB has some issues around selecting self-refs too.
- if exclusions.against(config._current, 'firebird'):
+ if exclusions.against(config._current, "firebird"):
table_name = args[0]
- unpack = (config.db.dialect.
- identifier_preparer.unformat_identifiers)
+ unpack = config.db.dialect.identifier_preparer.unformat_identifiers
# Only going after ForeignKeys in Columns. May need to
# expand to ForeignKeyConstraint too.
- fks = [fk
- for col in args if isinstance(col, schema.Column)
- for fk in col.foreign_keys]
+ fks = [
+ fk
+ for col in args
+ if isinstance(col, schema.Column)
+ for fk in col.foreign_keys
+ ]
for fk in fks:
# root around in raw spec
@@ -54,9 +59,9 @@ def Table(*args, **kw):
name = unpack(ref)[0]
if name == table_name:
if fk.ondelete is None:
- fk.ondelete = 'CASCADE'
+ fk.ondelete = "CASCADE"
if fk.onupdate is None:
- fk.onupdate = 'CASCADE'
+ fk.onupdate = "CASCADE"
return schema.Table(*args, **kw)
@@ -64,37 +69,46 @@ def Table(*args, **kw):
def Column(*args, **kw):
"""A schema.Column wrapper/hook for dialect-specific tweaks."""
- test_opts = {k: kw.pop(k) for k in list(kw) if k.startswith('test_')}
+ test_opts = {k: kw.pop(k) for k in list(kw) if k.startswith("test_")}
if not config.requirements.foreign_key_ddl.enabled_for_config(config):
args = [arg for arg in args if not isinstance(arg, schema.ForeignKey)]
col = schema.Column(*args, **kw)
- if test_opts.get('test_needs_autoincrement', False) and \
- kw.get('primary_key', False):
+ if test_opts.get("test_needs_autoincrement", False) and kw.get(
+ "primary_key", False
+ ):
if col.default is None and col.server_default is None:
col.autoincrement = True
# allow any test suite to pick up on this
- col.info['test_needs_autoincrement'] = True
+ col.info["test_needs_autoincrement"] = True
# hardcoded rule for firebird, oracle; this should
# be moved out
- if exclusions.against(config._current, 'firebird', 'oracle'):
+ if exclusions.against(config._current, "firebird", "oracle"):
+
def add_seq(c, tbl):
c._init_items(
- schema.Sequence(_truncate_name(
- config.db.dialect, tbl.name + '_' + c.name + '_seq'),
- optional=True)
+ schema.Sequence(
+ _truncate_name(
+ config.db.dialect, tbl.name + "_" + c.name + "_seq"
+ ),
+ optional=True,
+ )
)
- event.listen(col, 'after_parent_attach', add_seq, propagate=True)
+
+ event.listen(col, "after_parent_attach", add_seq, propagate=True)
return col
def _truncate_name(dialect, name):
if len(name) > dialect.max_identifier_length:
- return name[0:max(dialect.max_identifier_length - 6, 0)] + \
- "_" + hex(hash(name) % 64)[2:]
+ return (
+ name[0 : max(dialect.max_identifier_length - 6, 0)]
+ + "_"
+ + hex(hash(name) % 64)[2:]
+ )
else:
return name
diff --git a/lib/sqlalchemy/testing/suite/__init__.py b/lib/sqlalchemy/testing/suite/__init__.py
index 748d9722d..a4e142c5a 100644
--- a/lib/sqlalchemy/testing/suite/__init__.py
+++ b/lib/sqlalchemy/testing/suite/__init__.py
@@ -1,4 +1,3 @@
-
from sqlalchemy.testing.suite.test_cte import *
from sqlalchemy.testing.suite.test_dialect import *
from sqlalchemy.testing.suite.test_ddl import *
diff --git a/lib/sqlalchemy/testing/suite/test_cte.py b/lib/sqlalchemy/testing/suite/test_cte.py
index cc72278e6..d2f35933b 100644
--- a/lib/sqlalchemy/testing/suite/test_cte.py
+++ b/lib/sqlalchemy/testing/suite/test_cte.py
@@ -10,22 +10,28 @@ from ..schema import Table, Column
class CTETest(fixtures.TablesTest):
__backend__ = True
- __requires__ = 'ctes',
+ __requires__ = ("ctes",)
- run_inserts = 'each'
- run_deletes = 'each'
+ run_inserts = "each"
+ run_deletes = "each"
@classmethod
def define_tables(cls, metadata):
- Table("some_table", metadata,
- Column('id', Integer, primary_key=True),
- Column('data', String(50)),
- Column("parent_id", ForeignKey("some_table.id")))
+ Table(
+ "some_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("data", String(50)),
+ Column("parent_id", ForeignKey("some_table.id")),
+ )
- Table("some_other_table", metadata,
- Column('id', Integer, primary_key=True),
- Column('data', String(50)),
- Column("parent_id", Integer))
+ Table(
+ "some_other_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("data", String(50)),
+ Column("parent_id", Integer),
+ )
@classmethod
def insert_data(cls):
@@ -36,28 +42,33 @@ class CTETest(fixtures.TablesTest):
{"id": 2, "data": "d2", "parent_id": 1},
{"id": 3, "data": "d3", "parent_id": 1},
{"id": 4, "data": "d4", "parent_id": 3},
- {"id": 5, "data": "d5", "parent_id": 3}
- ]
+ {"id": 5, "data": "d5", "parent_id": 3},
+ ],
)
def test_select_nonrecursive_round_trip(self):
some_table = self.tables.some_table
with config.db.connect() as conn:
- cte = select([some_table]).where(
- some_table.c.data.in_(["d2", "d3", "d4"])).cte("some_cte")
+ cte = (
+ select([some_table])
+ .where(some_table.c.data.in_(["d2", "d3", "d4"]))
+ .cte("some_cte")
+ )
result = conn.execute(
select([cte.c.data]).where(cte.c.data.in_(["d4", "d5"]))
)
- eq_(result.fetchall(), [("d4", )])
+ eq_(result.fetchall(), [("d4",)])
def test_select_recursive_round_trip(self):
some_table = self.tables.some_table
with config.db.connect() as conn:
- cte = select([some_table]).where(
- some_table.c.data.in_(["d2", "d3", "d4"])).cte(
- "some_cte", recursive=True)
+ cte = (
+ select([some_table])
+ .where(some_table.c.data.in_(["d2", "d3", "d4"]))
+ .cte("some_cte", recursive=True)
+ )
cte_alias = cte.alias("c1")
st1 = some_table.alias()
@@ -67,12 +78,13 @@ class CTETest(fixtures.TablesTest):
select([st1]).where(st1.c.id == cte_alias.c.parent_id)
)
result = conn.execute(
- select([cte.c.data]).where(
- cte.c.data != "d2").order_by(cte.c.data.desc())
+ select([cte.c.data])
+ .where(cte.c.data != "d2")
+ .order_by(cte.c.data.desc())
)
eq_(
result.fetchall(),
- [('d4',), ('d3',), ('d3',), ('d1',), ('d1',), ('d1',)]
+ [("d4",), ("d3",), ("d3",), ("d1",), ("d1",), ("d1",)],
)
def test_insert_from_select_round_trip(self):
@@ -80,20 +92,21 @@ class CTETest(fixtures.TablesTest):
some_other_table = self.tables.some_other_table
with config.db.connect() as conn:
- cte = select([some_table]).where(
- some_table.c.data.in_(["d2", "d3", "d4"])
- ).cte("some_cte")
+ cte = (
+ select([some_table])
+ .where(some_table.c.data.in_(["d2", "d3", "d4"]))
+ .cte("some_cte")
+ )
conn.execute(
some_other_table.insert().from_select(
- ["id", "data", "parent_id"],
- select([cte])
+ ["id", "data", "parent_id"], select([cte])
)
)
eq_(
conn.execute(
select([some_other_table]).order_by(some_other_table.c.id)
).fetchall(),
- [(2, "d2", 1), (3, "d3", 1), (4, "d4", 3)]
+ [(2, "d2", 1), (3, "d3", 1), (4, "d4", 3)],
)
@testing.requires.ctes_with_update_delete
@@ -105,27 +118,31 @@ class CTETest(fixtures.TablesTest):
with config.db.connect() as conn:
conn.execute(
some_other_table.insert().from_select(
- ['id', 'data', 'parent_id'],
- select([some_table])
+ ["id", "data", "parent_id"], select([some_table])
)
)
- cte = select([some_table]).where(
- some_table.c.data.in_(["d2", "d3", "d4"])
- ).cte("some_cte")
+ cte = (
+ select([some_table])
+ .where(some_table.c.data.in_(["d2", "d3", "d4"]))
+ .cte("some_cte")
+ )
conn.execute(
- some_other_table.update().values(parent_id=5).where(
- some_other_table.c.data == cte.c.data
- )
+ some_other_table.update()
+ .values(parent_id=5)
+ .where(some_other_table.c.data == cte.c.data)
)
eq_(
conn.execute(
select([some_other_table]).order_by(some_other_table.c.id)
).fetchall(),
[
- (1, "d1", None), (2, "d2", 5),
- (3, "d3", 5), (4, "d4", 5), (5, "d5", 3)
- ]
+ (1, "d1", None),
+ (2, "d2", 5),
+ (3, "d3", 5),
+ (4, "d4", 5),
+ (5, "d5", 3),
+ ],
)
@testing.requires.ctes_with_update_delete
@@ -137,14 +154,15 @@ class CTETest(fixtures.TablesTest):
with config.db.connect() as conn:
conn.execute(
some_other_table.insert().from_select(
- ['id', 'data', 'parent_id'],
- select([some_table])
+ ["id", "data", "parent_id"], select([some_table])
)
)
- cte = select([some_table]).where(
- some_table.c.data.in_(["d2", "d3", "d4"])
- ).cte("some_cte")
+ cte = (
+ select([some_table])
+ .where(some_table.c.data.in_(["d2", "d3", "d4"]))
+ .cte("some_cte")
+ )
conn.execute(
some_other_table.delete().where(
some_other_table.c.data == cte.c.data
@@ -154,9 +172,7 @@ class CTETest(fixtures.TablesTest):
conn.execute(
select([some_other_table]).order_by(some_other_table.c.id)
).fetchall(),
- [
- (1, "d1", None), (5, "d5", 3)
- ]
+ [(1, "d1", None), (5, "d5", 3)],
)
@testing.requires.ctes_with_update_delete
@@ -168,26 +184,26 @@ class CTETest(fixtures.TablesTest):
with config.db.connect() as conn:
conn.execute(
some_other_table.insert().from_select(
- ['id', 'data', 'parent_id'],
- select([some_table])
+ ["id", "data", "parent_id"], select([some_table])
)
)
- cte = select([some_table]).where(
- some_table.c.data.in_(["d2", "d3", "d4"])
- ).cte("some_cte")
+ cte = (
+ select([some_table])
+ .where(some_table.c.data.in_(["d2", "d3", "d4"]))
+ .cte("some_cte")
+ )
conn.execute(
some_other_table.delete().where(
- some_other_table.c.data ==
- select([cte.c.data]).where(
- cte.c.id == some_other_table.c.id)
+ some_other_table.c.data
+ == select([cte.c.data]).where(
+ cte.c.id == some_other_table.c.id
+ )
)
)
eq_(
conn.execute(
select([some_other_table]).order_by(some_other_table.c.id)
).fetchall(),
- [
- (1, "d1", None), (5, "d5", 3)
- ]
+ [(1, "d1", None), (5, "d5", 3)],
)
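
test_select_recursive_round_trip above follows the standard recursive-CTE construction: seed rows, union_all the CTE back onto itself, then select from the result. The same round trip as a self-contained script against in-memory SQLite, written in the 1.x-era select([...]) calling style used throughout this diff:

from sqlalchemy import (
    create_engine, MetaData, Table, Column,
    Integer, String, ForeignKey, select,
)

engine = create_engine("sqlite://")
metadata = MetaData()
some_table = Table(
    "some_table",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("data", String(50)),
    Column("parent_id", ForeignKey("some_table.id")),
)
metadata.create_all(engine)

with engine.connect() as conn:
    conn.execute(
        some_table.insert(),
        [
            {"id": 1, "data": "d1", "parent_id": None},
            {"id": 2, "data": "d2", "parent_id": 1},
            {"id": 3, "data": "d3", "parent_id": 1},
        ],
    )
    # anchor term: the rows we start from
    cte = (
        select([some_table])
        .where(some_table.c.data.in_(["d2", "d3"]))
        .cte("some_cte", recursive=True)
    )
    # recursive term: walk from each row to its parent
    st1 = some_table.alias()
    cte = cte.union_all(select([st1]).where(st1.c.id == cte.c.parent_id))
    result = conn.execute(
        select([cte.c.data]).order_by(cte.c.data.desc())
    )
    # d1 appears once per child reached, since union_all keeps duplicates
    print(result.fetchall())  # [('d3',), ('d2',), ('d1',), ('d1',)]
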
diff --git a/lib/sqlalchemy/testing/suite/test_ddl.py b/lib/sqlalchemy/testing/suite/test_ddl.py
index 1d8010c8a..7c44388d4 100644
--- a/lib/sqlalchemy/testing/suite/test_ddl.py
+++ b/lib/sqlalchemy/testing/suite/test_ddl.py
@@ -1,5 +1,3 @@
-
-
from .. import fixtures, config, util
from ..config import requirements
from ..assertions import eq_
@@ -11,55 +9,47 @@ class TableDDLTest(fixtures.TestBase):
__backend__ = True
def _simple_fixture(self):
- return Table('test_table', self.metadata,
- Column('id', Integer, primary_key=True,
- autoincrement=False),
- Column('data', String(50))
- )
+ return Table(
+ "test_table",
+ self.metadata,
+ Column("id", Integer, primary_key=True, autoincrement=False),
+ Column("data", String(50)),
+ )
def _underscore_fixture(self):
- return Table('_test_table', self.metadata,
- Column('id', Integer, primary_key=True,
- autoincrement=False),
- Column('_data', String(50))
- )
+ return Table(
+ "_test_table",
+ self.metadata,
+ Column("id", Integer, primary_key=True, autoincrement=False),
+ Column("_data", String(50)),
+ )
def _simple_roundtrip(self, table):
with config.db.begin() as conn:
- conn.execute(table.insert().values((1, 'some data')))
+ conn.execute(table.insert().values((1, "some data")))
result = conn.execute(table.select())
- eq_(
- result.first(),
- (1, 'some data')
- )
+ eq_(result.first(), (1, "some data"))
@requirements.create_table
@util.provide_metadata
def test_create_table(self):
table = self._simple_fixture()
- table.create(
- config.db, checkfirst=False
- )
+ table.create(config.db, checkfirst=False)
self._simple_roundtrip(table)
@requirements.drop_table
@util.provide_metadata
def test_drop_table(self):
table = self._simple_fixture()
- table.create(
- config.db, checkfirst=False
- )
- table.drop(
- config.db, checkfirst=False
- )
+ table.create(config.db, checkfirst=False)
+ table.drop(config.db, checkfirst=False)
@requirements.create_table
@util.provide_metadata
def test_underscore_names(self):
table = self._underscore_fixture()
- table.create(
- config.db, checkfirst=False
- )
+ table.create(config.db, checkfirst=False)
self._simple_roundtrip(table)
-__all__ = ('TableDDLTest', )
+
+__all__ = ("TableDDLTest",)
diff --git a/lib/sqlalchemy/testing/suite/test_dialect.py b/lib/sqlalchemy/testing/suite/test_dialect.py
index 2c5dd0e36..5e589f3b8 100644
--- a/lib/sqlalchemy/testing/suite/test_dialect.py
+++ b/lib/sqlalchemy/testing/suite/test_dialect.py
@@ -15,16 +15,19 @@ class ExceptionTest(fixtures.TablesTest):
specific exceptions from real round trips, we need to be conservative.
"""
- run_deletes = 'each'
+
+ run_deletes = "each"
__backend__ = True
@classmethod
def define_tables(cls, metadata):
- Table('manual_pk', metadata,
- Column('id', Integer, primary_key=True, autoincrement=False),
- Column('data', String(50))
- )
+ Table(
+ "manual_pk",
+ metadata,
+ Column("id", Integer, primary_key=True, autoincrement=False),
+ Column("data", String(50)),
+ )
@requirements.duplicate_key_raises_integrity_error
def test_integrity_error(self):
@@ -33,15 +36,14 @@ class ExceptionTest(fixtures.TablesTest):
trans = conn.begin()
conn.execute(
- self.tables.manual_pk.insert(),
- {'id': 1, 'data': 'd1'}
+ self.tables.manual_pk.insert(), {"id": 1, "data": "d1"}
)
assert_raises(
exc.IntegrityError,
conn.execute,
self.tables.manual_pk.insert(),
- {'id': 1, 'data': 'd1'}
+ {"id": 1, "data": "d1"},
)
trans.rollback()
@@ -49,38 +51,39 @@ class ExceptionTest(fixtures.TablesTest):
class AutocommitTest(fixtures.TablesTest):
- run_deletes = 'each'
+ run_deletes = "each"
- __requires__ = 'autocommit',
+ __requires__ = ("autocommit",)
__backend__ = True
@classmethod
def define_tables(cls, metadata):
- Table('some_table', metadata,
- Column('id', Integer, primary_key=True, autoincrement=False),
- Column('data', String(50)),
- test_needs_acid=True
- )
+ Table(
+ "some_table",
+ metadata,
+ Column("id", Integer, primary_key=True, autoincrement=False),
+ Column("data", String(50)),
+ test_needs_acid=True,
+ )
def _test_conn_autocommits(self, conn, autocommit):
trans = conn.begin()
conn.execute(
- self.tables.some_table.insert(),
- {"id": 1, "data": "some data"}
+ self.tables.some_table.insert(), {"id": 1, "data": "some data"}
)
trans.rollback()
eq_(
conn.scalar(select([self.tables.some_table.c.id])),
- 1 if autocommit else None
+ 1 if autocommit else None,
)
conn.execute(self.tables.some_table.delete())
def test_autocommit_on(self):
conn = config.db.connect()
- c2 = conn.execution_options(isolation_level='AUTOCOMMIT')
+ c2 = conn.execution_options(isolation_level="AUTOCOMMIT")
self._test_conn_autocommits(c2, True)
conn.invalidate()
self._test_conn_autocommits(conn, False)
@@ -98,7 +101,7 @@ class EscapingTest(fixtures.TestBase):
"""
m = self.metadata
- t = Table('t', m, Column('data', String(50)))
+ t = Table("t", m, Column("data", String(50)))
t.create(config.db)
with config.db.begin() as conn:
conn.execute(t.insert(), dict(data="some % value"))
@@ -107,14 +110,17 @@ class EscapingTest(fixtures.TestBase):
eq_(
conn.scalar(
select([t.c.data]).where(
- t.c.data == literal_column("'some % value'"))
+ t.c.data == literal_column("'some % value'")
+ )
),
- "some % value"
+ "some % value",
)
eq_(
conn.scalar(
select([t.c.data]).where(
- t.c.data == literal_column("'some %% other value'"))
- ), "some %% other value"
+ t.c.data == literal_column("'some %% other value'")
+ )
+ ),
+ "some %% other value",
)
diff --git a/lib/sqlalchemy/testing/suite/test_insert.py b/lib/sqlalchemy/testing/suite/test_insert.py
index c0b6b18eb..6257451eb 100644
--- a/lib/sqlalchemy/testing/suite/test_insert.py
+++ b/lib/sqlalchemy/testing/suite/test_insert.py
@@ -10,53 +10,48 @@ from ..schema import Table, Column
class LastrowidTest(fixtures.TablesTest):
- run_deletes = 'each'
+ run_deletes = "each"
__backend__ = True
- __requires__ = 'implements_get_lastrowid', 'autoincrement_insert'
+ __requires__ = "implements_get_lastrowid", "autoincrement_insert"
__engine_options__ = {"implicit_returning": False}
@classmethod
def define_tables(cls, metadata):
- Table('autoinc_pk', metadata,
- Column('id', Integer, primary_key=True,
- test_needs_autoincrement=True),
- Column('data', String(50))
- )
+ Table(
+ "autoinc_pk",
+ metadata,
+ Column(
+ "id", Integer, primary_key=True, test_needs_autoincrement=True
+ ),
+ Column("data", String(50)),
+ )
- Table('manual_pk', metadata,
- Column('id', Integer, primary_key=True, autoincrement=False),
- Column('data', String(50))
- )
+ Table(
+ "manual_pk",
+ metadata,
+ Column("id", Integer, primary_key=True, autoincrement=False),
+ Column("data", String(50)),
+ )
def _assert_round_trip(self, table, conn):
row = conn.execute(table.select()).first()
- eq_(
- row,
- (config.db.dialect.default_sequence_base, "some data")
- )
+ eq_(row, (config.db.dialect.default_sequence_base, "some data"))
def test_autoincrement_on_insert(self):
- config.db.execute(
- self.tables.autoinc_pk.insert(),
- data="some data"
- )
+ config.db.execute(self.tables.autoinc_pk.insert(), data="some data")
self._assert_round_trip(self.tables.autoinc_pk, config.db)
def test_last_inserted_id(self):
r = config.db.execute(
- self.tables.autoinc_pk.insert(),
- data="some data"
+ self.tables.autoinc_pk.insert(), data="some data"
)
pk = config.db.scalar(select([self.tables.autoinc_pk.c.id]))
- eq_(
- r.inserted_primary_key,
- [pk]
- )
+ eq_(r.inserted_primary_key, [pk])
# failed on pypy1.9 but seems to be OK on pypy 2.1
# @exclusions.fails_if(lambda: util.pypy,
@@ -65,50 +60,57 @@ class LastrowidTest(fixtures.TablesTest):
@requirements.dbapi_lastrowid
def test_native_lastrowid_autoinc(self):
r = config.db.execute(
- self.tables.autoinc_pk.insert(),
- data="some data"
+ self.tables.autoinc_pk.insert(), data="some data"
)
lastrowid = r.lastrowid
pk = config.db.scalar(select([self.tables.autoinc_pk.c.id]))
- eq_(
- lastrowid, pk
- )
+ eq_(lastrowid, pk)
class InsertBehaviorTest(fixtures.TablesTest):
- run_deletes = 'each'
+ run_deletes = "each"
__backend__ = True
@classmethod
def define_tables(cls, metadata):
- Table('autoinc_pk', metadata,
- Column('id', Integer, primary_key=True,
- test_needs_autoincrement=True),
- Column('data', String(50))
- )
- Table('manual_pk', metadata,
- Column('id', Integer, primary_key=True, autoincrement=False),
- Column('data', String(50))
- )
- Table('includes_defaults', metadata,
- Column('id', Integer, primary_key=True,
- test_needs_autoincrement=True),
- Column('data', String(50)),
- Column('x', Integer, default=5),
- Column('y', Integer,
- default=literal_column("2", type_=Integer) + literal(2)))
+ Table(
+ "autoinc_pk",
+ metadata,
+ Column(
+ "id", Integer, primary_key=True, test_needs_autoincrement=True
+ ),
+ Column("data", String(50)),
+ )
+ Table(
+ "manual_pk",
+ metadata,
+ Column("id", Integer, primary_key=True, autoincrement=False),
+ Column("data", String(50)),
+ )
+ Table(
+ "includes_defaults",
+ metadata,
+ Column(
+ "id", Integer, primary_key=True, test_needs_autoincrement=True
+ ),
+ Column("data", String(50)),
+ Column("x", Integer, default=5),
+ Column(
+ "y",
+ Integer,
+ default=literal_column("2", type_=Integer) + literal(2),
+ ),
+ )
def test_autoclose_on_insert(self):
if requirements.returning.enabled:
engine = engines.testing_engine(
- options={'implicit_returning': False})
+ options={"implicit_returning": False}
+ )
else:
engine = config.db
- r = engine.execute(
- self.tables.autoinc_pk.insert(),
- data="some data"
- )
+ r = engine.execute(self.tables.autoinc_pk.insert(), data="some data")
assert r._soft_closed
assert not r.closed
assert r.is_insert
@@ -117,8 +119,7 @@ class InsertBehaviorTest(fixtures.TablesTest):
@requirements.returning
def test_autoclose_on_insert_implicit_returning(self):
r = config.db.execute(
- self.tables.autoinc_pk.insert(),
- data="some data"
+ self.tables.autoinc_pk.insert(), data="some data"
)
assert r._soft_closed
assert not r.closed
@@ -127,15 +128,14 @@ class InsertBehaviorTest(fixtures.TablesTest):
@requirements.empty_inserts
def test_empty_insert(self):
- r = config.db.execute(
- self.tables.autoinc_pk.insert(),
- )
+ r = config.db.execute(self.tables.autoinc_pk.insert())
assert r._soft_closed
assert not r.closed
r = config.db.execute(
- self.tables.autoinc_pk.select().
- where(self.tables.autoinc_pk.c.id != None)
+ self.tables.autoinc_pk.select().where(
+ self.tables.autoinc_pk.c.id != None
+ )
)
assert len(r.fetchall())
@@ -150,15 +150,15 @@ class InsertBehaviorTest(fixtures.TablesTest):
dict(id=1, data="data1"),
dict(id=2, data="data2"),
dict(id=3, data="data3"),
- ]
+ ],
)
result = config.db.execute(
- dest_table.insert().
- from_select(
+ dest_table.insert().from_select(
("data",),
- select([src_table.c.data]).
- where(src_table.c.data.in_(["data2", "data3"]))
+ select([src_table.c.data]).where(
+ src_table.c.data.in_(["data2", "data3"])
+ ),
)
)
@@ -167,7 +167,7 @@ class InsertBehaviorTest(fixtures.TablesTest):
result = config.db.execute(
select([dest_table.c.data]).order_by(dest_table.c.data)
)
- eq_(result.fetchall(), [("data2", ), ("data3", )])
+ eq_(result.fetchall(), [("data2",), ("data3",)])
@requirements.insert_from_select
def test_insert_from_select_autoinc_no_rows(self):
@@ -175,11 +175,11 @@ class InsertBehaviorTest(fixtures.TablesTest):
dest_table = self.tables.autoinc_pk
result = config.db.execute(
- dest_table.insert().
- from_select(
+ dest_table.insert().from_select(
("data",),
- select([src_table.c.data]).
- where(src_table.c.data.in_(["data2", "data3"]))
+ select([src_table.c.data]).where(
+ src_table.c.data.in_(["data2", "data3"])
+ ),
)
)
eq_(result.inserted_primary_key, [None])
@@ -199,23 +199,23 @@ class InsertBehaviorTest(fixtures.TablesTest):
dict(id=1, data="data1"),
dict(id=2, data="data2"),
dict(id=3, data="data3"),
- ]
+ ],
)
config.db.execute(
- table.insert(inline=True).
- from_select(("id", "data",),
- select([table.c.id + 5, table.c.data]).
- where(table.c.data.in_(["data2", "data3"]))
- ),
+ table.insert(inline=True).from_select(
+ ("id", "data"),
+ select([table.c.id + 5, table.c.data]).where(
+ table.c.data.in_(["data2", "data3"])
+ ),
+ )
)
eq_(
config.db.execute(
select([table.c.data]).order_by(table.c.data)
).fetchall(),
- [("data1", ), ("data2", ), ("data2", ),
- ("data3", ), ("data3", )]
+ [("data1",), ("data2",), ("data2",), ("data3",), ("data3",)],
)
@requirements.insert_from_select
@@ -227,56 +227,60 @@ class InsertBehaviorTest(fixtures.TablesTest):
dict(id=1, data="data1"),
dict(id=2, data="data2"),
dict(id=3, data="data3"),
- ]
+ ],
)
config.db.execute(
- table.insert(inline=True).
- from_select(("id", "data",),
- select([table.c.id + 5, table.c.data]).
- where(table.c.data.in_(["data2", "data3"]))
- ),
+ table.insert(inline=True).from_select(
+ ("id", "data"),
+ select([table.c.id + 5, table.c.data]).where(
+ table.c.data.in_(["data2", "data3"])
+ ),
+ )
)
eq_(
config.db.execute(
select([table]).order_by(table.c.data, table.c.id)
).fetchall(),
- [(1, 'data1', 5, 4), (2, 'data2', 5, 4),
- (7, 'data2', 5, 4), (3, 'data3', 5, 4), (8, 'data3', 5, 4)]
+ [
+ (1, "data1", 5, 4),
+ (2, "data2", 5, 4),
+ (7, "data2", 5, 4),
+ (3, "data3", 5, 4),
+ (8, "data3", 5, 4),
+ ],
)
class ReturningTest(fixtures.TablesTest):
- run_create_tables = 'each'
- __requires__ = 'returning', 'autoincrement_insert'
+ run_create_tables = "each"
+ __requires__ = "returning", "autoincrement_insert"
__backend__ = True
__engine_options__ = {"implicit_returning": True}
def _assert_round_trip(self, table, conn):
row = conn.execute(table.select()).first()
- eq_(
- row,
- (config.db.dialect.default_sequence_base, "some data")
- )
+ eq_(row, (config.db.dialect.default_sequence_base, "some data"))
@classmethod
def define_tables(cls, metadata):
- Table('autoinc_pk', metadata,
- Column('id', Integer, primary_key=True,
- test_needs_autoincrement=True),
- Column('data', String(50))
- )
+ Table(
+ "autoinc_pk",
+ metadata,
+ Column(
+ "id", Integer, primary_key=True, test_needs_autoincrement=True
+ ),
+ Column("data", String(50)),
+ )
@requirements.fetch_rows_post_commit
def test_explicit_returning_pk_autocommit(self):
engine = config.db
table = self.tables.autoinc_pk
r = engine.execute(
- table.insert().returning(
- table.c.id),
- data="some data"
+ table.insert().returning(table.c.id), data="some data"
)
pk = r.first()[0]
fetched_pk = config.db.scalar(select([table.c.id]))
@@ -287,9 +291,7 @@ class ReturningTest(fixtures.TablesTest):
table = self.tables.autoinc_pk
with engine.begin() as conn:
r = conn.execute(
- table.insert().returning(
- table.c.id),
- data="some data"
+ table.insert().returning(table.c.id), data="some data"
)
pk = r.first()[0]
fetched_pk = config.db.scalar(select([table.c.id]))
@@ -297,23 +299,16 @@ class ReturningTest(fixtures.TablesTest):
def test_autoincrement_on_insert_implcit_returning(self):
- config.db.execute(
- self.tables.autoinc_pk.insert(),
- data="some data"
- )
+ config.db.execute(self.tables.autoinc_pk.insert(), data="some data")
self._assert_round_trip(self.tables.autoinc_pk, config.db)
def test_last_inserted_id_implicit_returning(self):
r = config.db.execute(
- self.tables.autoinc_pk.insert(),
- data="some data"
+ self.tables.autoinc_pk.insert(), data="some data"
)
pk = config.db.scalar(select([self.tables.autoinc_pk.c.id]))
- eq_(
- r.inserted_primary_key,
- [pk]
- )
+ eq_(r.inserted_primary_key, [pk])
-__all__ = ('LastrowidTest', 'InsertBehaviorTest', 'ReturningTest')
+__all__ = ("LastrowidTest", "InsertBehaviorTest", "ReturningTest")
diff --git a/lib/sqlalchemy/testing/suite/test_reflection.py b/lib/sqlalchemy/testing/suite/test_reflection.py
index 00a5aac01..bfed5f1ab 100644
--- a/lib/sqlalchemy/testing/suite/test_reflection.py
+++ b/lib/sqlalchemy/testing/suite/test_reflection.py
@@ -1,5 +1,3 @@
-
-
import sqlalchemy as sa
from sqlalchemy import exc as sa_exc
from sqlalchemy import types as sql_types
@@ -26,10 +24,12 @@ class HasTableTest(fixtures.TablesTest):
@classmethod
def define_tables(cls, metadata):
- Table('test_table', metadata,
- Column('id', Integer, primary_key=True),
- Column('data', String(50))
- )
+ Table(
+ "test_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("data", String(50)),
+ )
def test_has_table(self):
with config.db.begin() as conn:
@@ -46,8 +46,10 @@ class ComponentReflectionTest(fixtures.TablesTest):
def setup_bind(cls):
if config.requirements.independent_connections.enabled:
from sqlalchemy import pool
+
return engines.testing_engine(
- options=dict(poolclass=pool.StaticPool))
+ options=dict(poolclass=pool.StaticPool)
+ )
else:
return config.db
@@ -65,86 +67,109 @@ class ComponentReflectionTest(fixtures.TablesTest):
schema_prefix = ""
if testing.requires.self_referential_foreign_keys.enabled:
- users = Table('users', metadata,
- Column('user_id', sa.INT, primary_key=True),
- Column('test1', sa.CHAR(5), nullable=False),
- Column('test2', sa.Float(5), nullable=False),
- Column('parent_user_id', sa.Integer,
- sa.ForeignKey('%susers.user_id' %
- schema_prefix,
- name='user_id_fk')),
- schema=schema,
- test_needs_fk=True,
- )
+ users = Table(
+ "users",
+ metadata,
+ Column("user_id", sa.INT, primary_key=True),
+ Column("test1", sa.CHAR(5), nullable=False),
+ Column("test2", sa.Float(5), nullable=False),
+ Column(
+ "parent_user_id",
+ sa.Integer,
+ sa.ForeignKey(
+ "%susers.user_id" % schema_prefix, name="user_id_fk"
+ ),
+ ),
+ schema=schema,
+ test_needs_fk=True,
+ )
else:
- users = Table('users', metadata,
- Column('user_id', sa.INT, primary_key=True),
- Column('test1', sa.CHAR(5), nullable=False),
- Column('test2', sa.Float(5), nullable=False),
- schema=schema,
- test_needs_fk=True,
- )
-
- Table("dingalings", metadata,
- Column('dingaling_id', sa.Integer, primary_key=True),
- Column('address_id', sa.Integer,
- sa.ForeignKey('%semail_addresses.address_id' %
- schema_prefix)),
- Column('data', sa.String(30)),
- schema=schema,
- test_needs_fk=True,
- )
- Table('email_addresses', metadata,
- Column('address_id', sa.Integer),
- Column('remote_user_id', sa.Integer,
- sa.ForeignKey(users.c.user_id)),
- Column('email_address', sa.String(20)),
- sa.PrimaryKeyConstraint('address_id', name='email_ad_pk'),
- schema=schema,
- test_needs_fk=True,
- )
- Table('comment_test', metadata,
- Column('id', sa.Integer, primary_key=True, comment='id comment'),
- Column('data', sa.String(20), comment='data % comment'),
- Column(
- 'd2', sa.String(20),
- comment=r"""Comment types type speedily ' " \ '' Fun!"""),
- schema=schema,
- comment=r"""the test % ' " \ table comment""")
+ users = Table(
+ "users",
+ metadata,
+ Column("user_id", sa.INT, primary_key=True),
+ Column("test1", sa.CHAR(5), nullable=False),
+ Column("test2", sa.Float(5), nullable=False),
+ schema=schema,
+ test_needs_fk=True,
+ )
+
+ Table(
+ "dingalings",
+ metadata,
+ Column("dingaling_id", sa.Integer, primary_key=True),
+ Column(
+ "address_id",
+ sa.Integer,
+ sa.ForeignKey("%semail_addresses.address_id" % schema_prefix),
+ ),
+ Column("data", sa.String(30)),
+ schema=schema,
+ test_needs_fk=True,
+ )
+ Table(
+ "email_addresses",
+ metadata,
+ Column("address_id", sa.Integer),
+ Column(
+ "remote_user_id", sa.Integer, sa.ForeignKey(users.c.user_id)
+ ),
+ Column("email_address", sa.String(20)),
+ sa.PrimaryKeyConstraint("address_id", name="email_ad_pk"),
+ schema=schema,
+ test_needs_fk=True,
+ )
+ Table(
+ "comment_test",
+ metadata,
+ Column("id", sa.Integer, primary_key=True, comment="id comment"),
+ Column("data", sa.String(20), comment="data % comment"),
+ Column(
+ "d2",
+ sa.String(20),
+ comment=r"""Comment types type speedily ' " \ '' Fun!""",
+ ),
+ schema=schema,
+ comment=r"""the test % ' " \ table comment""",
+ )
if testing.requires.cross_schema_fk_reflection.enabled:
if schema is None:
Table(
- 'local_table', metadata,
- Column('id', sa.Integer, primary_key=True),
- Column('data', sa.String(20)),
+ "local_table",
+ metadata,
+ Column("id", sa.Integer, primary_key=True),
+ Column("data", sa.String(20)),
Column(
- 'remote_id',
+ "remote_id",
ForeignKey(
- '%s.remote_table_2.id' %
- testing.config.test_schema)
+ "%s.remote_table_2.id" % testing.config.test_schema
+ ),
),
test_needs_fk=True,
- schema=config.db.dialect.default_schema_name
+ schema=config.db.dialect.default_schema_name,
)
else:
Table(
- 'remote_table', metadata,
- Column('id', sa.Integer, primary_key=True),
+ "remote_table",
+ metadata,
+ Column("id", sa.Integer, primary_key=True),
Column(
- 'local_id',
+ "local_id",
ForeignKey(
- '%s.local_table.id' %
- config.db.dialect.default_schema_name)
+ "%s.local_table.id"
+ % config.db.dialect.default_schema_name
+ ),
),
- Column('data', sa.String(20)),
+ Column("data", sa.String(20)),
schema=schema,
test_needs_fk=True,
)
Table(
- 'remote_table_2', metadata,
- Column('id', sa.Integer, primary_key=True),
- Column('data', sa.String(20)),
+ "remote_table_2",
+ metadata,
+ Column("id", sa.Integer, primary_key=True),
+ Column("data", sa.String(20)),
schema=schema,
test_needs_fk=True,
)
@@ -155,19 +180,21 @@ class ComponentReflectionTest(fixtures.TablesTest):
if not schema:
# test_needs_fk is used at the moment to force MySQL InnoDB
noncol_idx_test_nopk = Table(
- 'noncol_idx_test_nopk', metadata,
- Column('q', sa.String(5)),
+ "noncol_idx_test_nopk",
+ metadata,
+ Column("q", sa.String(5)),
test_needs_fk=True,
)
noncol_idx_test_pk = Table(
- 'noncol_idx_test_pk', metadata,
- Column('id', sa.Integer, primary_key=True),
- Column('q', sa.String(5)),
+ "noncol_idx_test_pk",
+ metadata,
+ Column("id", sa.Integer, primary_key=True),
+ Column("q", sa.String(5)),
test_needs_fk=True,
)
- Index('noncol_idx_nopk', noncol_idx_test_nopk.c.q.desc())
- Index('noncol_idx_pk', noncol_idx_test_pk.c.q.desc())
+ Index("noncol_idx_nopk", noncol_idx_test_nopk.c.q.desc())
+ Index("noncol_idx_pk", noncol_idx_test_pk.c.q.desc())
if testing.requires.view_column_reflection.enabled:
cls.define_views(metadata, schema)
@@ -180,34 +207,35 @@ class ComponentReflectionTest(fixtures.TablesTest):
# temp table fixture
if testing.against("oracle"):
kw = {
- 'prefixes': ["GLOBAL TEMPORARY"],
- 'oracle_on_commit': 'PRESERVE ROWS'
+ "prefixes": ["GLOBAL TEMPORARY"],
+ "oracle_on_commit": "PRESERVE ROWS",
}
else:
- kw = {
- 'prefixes': ["TEMPORARY"],
- }
+ kw = {"prefixes": ["TEMPORARY"]}
user_tmp = Table(
- "user_tmp", metadata,
+ "user_tmp",
+ metadata,
Column("id", sa.INT, primary_key=True),
- Column('name', sa.VARCHAR(50)),
- Column('foo', sa.INT),
- sa.UniqueConstraint('name', name='user_tmp_uq'),
+ Column("name", sa.VARCHAR(50)),
+ Column("foo", sa.INT),
+ sa.UniqueConstraint("name", name="user_tmp_uq"),
sa.Index("user_tmp_ix", "foo"),
**kw
)
- if testing.requires.view_reflection.enabled and \
- testing.requires.temporary_views.enabled:
- event.listen(
- user_tmp, "after_create",
- DDL("create temporary view user_tmp_v as "
- "select * from user_tmp")
- )
+ if (
+ testing.requires.view_reflection.enabled
+ and testing.requires.temporary_views.enabled
+ ):
event.listen(
- user_tmp, "before_drop",
- DDL("drop view user_tmp_v")
+ user_tmp,
+ "after_create",
+ DDL(
+ "create temporary view user_tmp_v as "
+ "select * from user_tmp"
+ ),
)
+ event.listen(user_tmp, "before_drop", DDL("drop view user_tmp_v"))
@classmethod
def define_index(cls, metadata, users):
@@ -216,23 +244,19 @@ class ComponentReflectionTest(fixtures.TablesTest):
@classmethod
def define_views(cls, metadata, schema):
- for table_name in ('users', 'email_addresses'):
+ for table_name in ("users", "email_addresses"):
fullname = table_name
if schema:
fullname = "%s.%s" % (schema, table_name)
- view_name = fullname + '_v'
+ view_name = fullname + "_v"
query = "CREATE VIEW %s AS SELECT * FROM %s" % (
- view_name, fullname)
-
- event.listen(
- metadata,
- "after_create",
- DDL(query)
+ view_name,
+ fullname,
)
+
+ event.listen(metadata, "after_create", DDL(query))
event.listen(
- metadata,
- "before_drop",
- DDL("DROP VIEW %s" % view_name)
+ metadata, "before_drop", DDL("DROP VIEW %s" % view_name)
)
@testing.requires.schema_reflection
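The define_views() fixture above ties raw CREATE VIEW / DROP VIEW statements to metadata lifecycle events; the pattern in isolation (names illustrative):

from sqlalchemy import Column, Integer, MetaData, Table, event
from sqlalchemy.schema import DDL

meta = MetaData()
Table("src_table", meta, Column("id", Integer, primary_key=True))

# after_create fires once the tables exist; before_drop runs ahead of them
event.listen(
    meta, "after_create",
    DDL("CREATE VIEW src_view AS SELECT * FROM src_table"),
)
event.listen(meta, "before_drop", DDL("DROP VIEW src_view"))
# meta.create_all(engine) / meta.drop_all(engine) now manage the view too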
@@ -244,9 +268,9 @@ class ComponentReflectionTest(fixtures.TablesTest):
@testing.requires.schema_reflection
def test_dialect_initialize(self):
engine = engines.testing_engine()
- assert not hasattr(engine.dialect, 'default_schema_name')
+ assert not hasattr(engine.dialect, "default_schema_name")
inspect(engine)
- assert hasattr(engine.dialect, 'default_schema_name')
+ assert hasattr(engine.dialect, "default_schema_name")
@testing.requires.schema_reflection
def test_get_default_schema_name(self):
@@ -254,40 +278,49 @@ class ComponentReflectionTest(fixtures.TablesTest):
eq_(insp.default_schema_name, testing.db.dialect.default_schema_name)
@testing.provide_metadata
- def _test_get_table_names(self, schema=None, table_type='table',
- order_by=None):
+ def _test_get_table_names(
+ self, schema=None, table_type="table", order_by=None
+ ):
_ignore_tables = [
- 'comment_test', 'noncol_idx_test_pk', 'noncol_idx_test_nopk',
- 'local_table', 'remote_table', 'remote_table_2'
+ "comment_test",
+ "noncol_idx_test_pk",
+ "noncol_idx_test_nopk",
+ "local_table",
+ "remote_table",
+ "remote_table_2",
]
meta = self.metadata
- users, addresses, dingalings = self.tables.users, \
- self.tables.email_addresses, self.tables.dingalings
+ users, addresses, dingalings = (
+ self.tables.users,
+ self.tables.email_addresses,
+ self.tables.dingalings,
+ )
insp = inspect(meta.bind)
- if table_type == 'view':
+ if table_type == "view":
table_names = insp.get_view_names(schema)
table_names.sort()
- answer = ['email_addresses_v', 'users_v']
+ answer = ["email_addresses_v", "users_v"]
eq_(sorted(table_names), answer)
else:
table_names = [
- t for t in insp.get_table_names(
- schema,
- order_by=order_by) if t not in _ignore_tables]
+ t
+ for t in insp.get_table_names(schema, order_by=order_by)
+ if t not in _ignore_tables
+ ]
- if order_by == 'foreign_key':
- answer = ['users', 'email_addresses', 'dingalings']
+ if order_by == "foreign_key":
+ answer = ["users", "email_addresses", "dingalings"]
eq_(table_names, answer)
else:
- answer = ['dingalings', 'email_addresses', 'users']
+ answer = ["dingalings", "email_addresses", "users"]
eq_(sorted(table_names), answer)
@testing.requires.temp_table_names
def test_get_temp_table_names(self):
insp = inspect(self.bind)
temp_table_names = insp.get_temp_table_names()
- eq_(sorted(temp_table_names), ['user_tmp'])
+ eq_(sorted(temp_table_names), ["user_tmp"])
@testing.requires.view_reflection
@testing.requires.temp_table_names
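_test_get_table_names() drives the Inspector's name-listing API; in application code the calls look like this (a sketch, assuming a configured engine):

from sqlalchemy import create_engine, inspect

engine = create_engine("sqlite://")  # assumption: any dialect
insp = inspect(engine)
insp.get_table_names()   # plain tables, optionally filtered by schema
insp.get_view_names()    # views, where the dialect reflects them
# temp tables/views need extra dialect support, hence the requirements
# decorators: insp.get_temp_table_names(), insp.get_temp_view_names()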
@@ -295,7 +328,7 @@ class ComponentReflectionTest(fixtures.TablesTest):
def test_get_temp_view_names(self):
insp = inspect(self.bind)
temp_table_names = insp.get_temp_view_names()
- eq_(sorted(temp_table_names), ['user_tmp_v'])
+ eq_(sorted(temp_table_names), ["user_tmp_v"])
@testing.requires.table_reflection
def test_get_table_names(self):
@@ -304,7 +337,7 @@ class ComponentReflectionTest(fixtures.TablesTest):
@testing.requires.table_reflection
@testing.requires.foreign_key_constraint_reflection
def test_get_table_names_fks(self):
- self._test_get_table_names(order_by='foreign_key')
+ self._test_get_table_names(order_by="foreign_key")
@testing.requires.comment_reflection
def test_get_comments(self):
@@ -320,26 +353,24 @@ class ComponentReflectionTest(fixtures.TablesTest):
eq_(
insp.get_table_comment("comment_test", schema=schema),
- {"text": r"""the test % ' " \ table comment"""}
+ {"text": r"""the test % ' " \ table comment"""},
)
- eq_(
- insp.get_table_comment("users", schema=schema),
- {"text": None}
- )
+ eq_(insp.get_table_comment("users", schema=schema), {"text": None})
eq_(
[
- {"name": rec['name'], "comment": rec['comment']}
- for rec in
- insp.get_columns("comment_test", schema=schema)
+ {"name": rec["name"], "comment": rec["comment"]}
+ for rec in insp.get_columns("comment_test", schema=schema)
],
[
- {'comment': 'id comment', 'name': 'id'},
- {'comment': 'data % comment', 'name': 'data'},
- {'comment': r"""Comment types type speedily ' " \ '' Fun!""",
- 'name': 'd2'}
- ]
+ {"comment": "id comment", "name": "id"},
+ {"comment": "data % comment", "name": "data"},
+ {
+ "comment": r"""Comment types type speedily ' " \ '' Fun!""",
+ "name": "d2",
+ },
+ ],
)
@testing.requires.table_reflection
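test_get_comments() above reads comments back through the same Inspector; a sketch reusing insp from the previous sketch (comment reflection assumes a dialect such as PostgreSQL, MySQL or Oracle, and the table name mirrors the fixture):

# returns {"text": "..."}, or {"text": None} when no comment is set
insp.get_table_comment("comment_test")

# column comments ride along in the get_columns() dicts
[(c["name"], c["comment"]) for c in insp.get_columns("comment_test")]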
@@ -349,30 +380,33 @@ class ComponentReflectionTest(fixtures.TablesTest):
@testing.requires.view_column_reflection
def test_get_view_names(self):
- self._test_get_table_names(table_type='view')
+ self._test_get_table_names(table_type="view")
@testing.requires.view_column_reflection
@testing.requires.schemas
def test_get_view_names_with_schema(self):
self._test_get_table_names(
- testing.config.test_schema, table_type='view')
+ testing.config.test_schema, table_type="view"
+ )
@testing.requires.table_reflection
@testing.requires.view_column_reflection
def test_get_tables_and_views(self):
self._test_get_table_names()
- self._test_get_table_names(table_type='view')
+ self._test_get_table_names(table_type="view")
- def _test_get_columns(self, schema=None, table_type='table'):
+ def _test_get_columns(self, schema=None, table_type="table"):
meta = MetaData(testing.db)
- users, addresses, dingalings = self.tables.users, \
- self.tables.email_addresses, self.tables.dingalings
- table_names = ['users', 'email_addresses']
- if table_type == 'view':
- table_names = ['users_v', 'email_addresses_v']
+ users, addresses, dingalings = (
+ self.tables.users,
+ self.tables.email_addresses,
+ self.tables.dingalings,
+ )
+ table_names = ["users", "email_addresses"]
+ if table_type == "view":
+ table_names = ["users_v", "email_addresses_v"]
insp = inspect(meta.bind)
- for table_name, table in zip(table_names, (users,
- addresses)):
+ for table_name, table in zip(table_names, (users, addresses)):
schema_name = schema
cols = insp.get_columns(table_name, schema=schema_name)
self.assert_(len(cols) > 0, len(cols))
@@ -380,36 +414,46 @@ class ComponentReflectionTest(fixtures.TablesTest):
# should be in order
for i, col in enumerate(table.columns):
- eq_(col.name, cols[i]['name'])
- ctype = cols[i]['type'].__class__
+ eq_(col.name, cols[i]["name"])
+ ctype = cols[i]["type"].__class__
ctype_def = col.type
if isinstance(ctype_def, sa.types.TypeEngine):
ctype_def = ctype_def.__class__
# Oracle returns Date for DateTime.
- if testing.against('oracle') and ctype_def \
- in (sql_types.Date, sql_types.DateTime):
+ if testing.against("oracle") and ctype_def in (
+ sql_types.Date,
+ sql_types.DateTime,
+ ):
ctype_def = sql_types.Date
# assert that the desired type and return type share
# a base within one of the generic types.
- self.assert_(len(set(ctype.__mro__).
- intersection(ctype_def.__mro__).
- intersection([
- sql_types.Integer,
- sql_types.Numeric,
- sql_types.DateTime,
- sql_types.Date,
- sql_types.Time,
- sql_types.String,
- sql_types._Binary,
- ])) > 0, '%s(%s), %s(%s)' %
- (col.name, col.type, cols[i]['name'], ctype))
+ self.assert_(
+ len(
+ set(ctype.__mro__)
+ .intersection(ctype_def.__mro__)
+ .intersection(
+ [
+ sql_types.Integer,
+ sql_types.Numeric,
+ sql_types.DateTime,
+ sql_types.Date,
+ sql_types.Time,
+ sql_types.String,
+ sql_types._Binary,
+ ]
+ )
+ )
+ > 0,
+ "%s(%s), %s(%s)"
+ % (col.name, col.type, cols[i]["name"], ctype),
+ )
if not col.primary_key:
- assert cols[i]['default'] is None
+ assert cols[i]["default"] is None
@testing.requires.table_reflection
def test_get_columns(self):
@@ -417,24 +461,20 @@ class ComponentReflectionTest(fixtures.TablesTest):
@testing.provide_metadata
def _type_round_trip(self, *types):
- t = Table('t', self.metadata,
- *[
- Column('t%d' % i, type_)
- for i, type_ in enumerate(types)
- ]
- )
+ t = Table(
+ "t",
+ self.metadata,
+ *[Column("t%d" % i, type_) for i, type_ in enumerate(types)]
+ )
t.create()
return [
- c['type'] for c in
- inspect(self.metadata.bind).get_columns('t')
+ c["type"] for c in inspect(self.metadata.bind).get_columns("t")
]
@testing.requires.table_reflection
def test_numeric_reflection(self):
- for typ in self._type_round_trip(
- sql_types.Numeric(18, 5),
- ):
+ for typ in self._type_round_trip(sql_types.Numeric(18, 5)):
assert isinstance(typ, sql_types.Numeric)
eq_(typ.precision, 18)
eq_(typ.scale, 5)
@@ -448,16 +488,19 @@ class ComponentReflectionTest(fixtures.TablesTest):
@testing.requires.table_reflection
@testing.provide_metadata
def test_nullable_reflection(self):
- t = Table('t', self.metadata,
- Column('a', Integer, nullable=True),
- Column('b', Integer, nullable=False))
+ t = Table(
+ "t",
+ self.metadata,
+ Column("a", Integer, nullable=True),
+ Column("b", Integer, nullable=False),
+ )
t.create()
eq_(
dict(
- (col['name'], col['nullable'])
- for col in inspect(self.metadata.bind).get_columns('t')
+ (col["name"], col["nullable"])
+ for col in inspect(self.metadata.bind).get_columns("t")
),
- {"a": True, "b": False}
+ {"a": True, "b": False},
)
@testing.requires.table_reflection
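The column assertions above rely on the keys every get_columns() dict carries in 1.x; sketched against the hypothetical users fixture table:

for col in insp.get_columns("users"):
    # name/type/nullable/default are part of the Inspector contract;
    # dialects may add extras such as autoincrement
    print(col["name"], col["type"], col["nullable"], col["default"])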
@@ -470,32 +513,30 @@ class ComponentReflectionTest(fixtures.TablesTest):
meta = MetaData(self.bind)
user_tmp = self.tables.user_tmp
insp = inspect(meta.bind)
- cols = insp.get_columns('user_tmp')
+ cols = insp.get_columns("user_tmp")
self.assert_(len(cols) > 0, len(cols))
for i, col in enumerate(user_tmp.columns):
- eq_(col.name, cols[i]['name'])
+ eq_(col.name, cols[i]["name"])
@testing.requires.temp_table_reflection
@testing.requires.view_column_reflection
@testing.requires.temporary_views
def test_get_temp_view_columns(self):
insp = inspect(self.bind)
- cols = insp.get_columns('user_tmp_v')
- eq_(
- [col['name'] for col in cols],
- ['id', 'name', 'foo']
- )
+ cols = insp.get_columns("user_tmp_v")
+ eq_([col["name"] for col in cols], ["id", "name", "foo"])
@testing.requires.view_column_reflection
def test_get_view_columns(self):
- self._test_get_columns(table_type='view')
+ self._test_get_columns(table_type="view")
@testing.requires.view_column_reflection
@testing.requires.schemas
def test_get_view_columns_with_schema(self):
self._test_get_columns(
- schema=testing.config.test_schema, table_type='view')
+ schema=testing.config.test_schema, table_type="view"
+ )
@testing.provide_metadata
def _test_get_pk_constraint(self, schema=None):
@@ -504,15 +545,15 @@ class ComponentReflectionTest(fixtures.TablesTest):
insp = inspect(meta.bind)
users_cons = insp.get_pk_constraint(users.name, schema=schema)
- users_pkeys = users_cons['constrained_columns']
- eq_(users_pkeys, ['user_id'])
+ users_pkeys = users_cons["constrained_columns"]
+ eq_(users_pkeys, ["user_id"])
addr_cons = insp.get_pk_constraint(addresses.name, schema=schema)
- addr_pkeys = addr_cons['constrained_columns']
- eq_(addr_pkeys, ['address_id'])
+ addr_pkeys = addr_cons["constrained_columns"]
+ eq_(addr_pkeys, ["address_id"])
with testing.requires.reflects_pk_names.fail_if():
- eq_(addr_cons['name'], 'email_ad_pk')
+ eq_(addr_cons["name"], "email_ad_pk")
@testing.requires.primary_key_constraint_reflection
def test_get_pk_constraint(self):
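get_pk_constraint() returns a single dict rather than a list; sketch (fixture-style names):

pk = insp.get_pk_constraint("email_addresses")
pk["constrained_columns"]  # e.g. ["address_id"]
pk["name"]                 # populated only where the dialect reflects PK names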
@@ -534,44 +575,46 @@ class ComponentReflectionTest(fixtures.TablesTest):
sa_exc.SADeprecationWarning,
"Call to deprecated method get_primary_keys."
" Use get_pk_constraint instead.",
- insp.get_primary_keys, users.name
+ insp.get_primary_keys,
+ users.name,
)
@testing.provide_metadata
def _test_get_foreign_keys(self, schema=None):
meta = self.metadata
- users, addresses, dingalings = self.tables.users, \
- self.tables.email_addresses, self.tables.dingalings
+ users, addresses, dingalings = (
+ self.tables.users,
+ self.tables.email_addresses,
+ self.tables.dingalings,
+ )
insp = inspect(meta.bind)
expected_schema = schema
# users
if testing.requires.self_referential_foreign_keys.enabled:
- users_fkeys = insp.get_foreign_keys(users.name,
- schema=schema)
+ users_fkeys = insp.get_foreign_keys(users.name, schema=schema)
fkey1 = users_fkeys[0]
with testing.requires.named_constraints.fail_if():
- eq_(fkey1['name'], "user_id_fk")
+ eq_(fkey1["name"], "user_id_fk")
- eq_(fkey1['referred_schema'], expected_schema)
- eq_(fkey1['referred_table'], users.name)
- eq_(fkey1['referred_columns'], ['user_id', ])
+ eq_(fkey1["referred_schema"], expected_schema)
+ eq_(fkey1["referred_table"], users.name)
+ eq_(fkey1["referred_columns"], ["user_id"])
if testing.requires.self_referential_foreign_keys.enabled:
- eq_(fkey1['constrained_columns'], ['parent_user_id'])
+ eq_(fkey1["constrained_columns"], ["parent_user_id"])
# addresses
- addr_fkeys = insp.get_foreign_keys(addresses.name,
- schema=schema)
+ addr_fkeys = insp.get_foreign_keys(addresses.name, schema=schema)
fkey1 = addr_fkeys[0]
with testing.requires.implicitly_named_constraints.fail_if():
- self.assert_(fkey1['name'] is not None)
+ self.assert_(fkey1["name"] is not None)
- eq_(fkey1['referred_schema'], expected_schema)
- eq_(fkey1['referred_table'], users.name)
- eq_(fkey1['referred_columns'], ['user_id', ])
- eq_(fkey1['constrained_columns'], ['remote_user_id'])
+ eq_(fkey1["referred_schema"], expected_schema)
+ eq_(fkey1["referred_table"], users.name)
+ eq_(fkey1["referred_columns"], ["user_id"])
+ eq_(fkey1["constrained_columns"], ["remote_user_id"])
@testing.requires.foreign_key_constraint_reflection
def test_get_foreign_keys(self):
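Foreign keys come back as one dict per constraint; sketch:

for fk in insp.get_foreign_keys("email_addresses"):
    print(
        fk["name"],                 # may be None on some backends
        fk["constrained_columns"],  # local columns
        fk["referred_schema"],      # None for the default schema
        fk["referred_table"],
        fk["referred_columns"],
    )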
@@ -586,9 +629,9 @@ class ComponentReflectionTest(fixtures.TablesTest):
@testing.requires.schemas
def test_get_inter_schema_foreign_keys(self):
local_table, remote_table, remote_table_2 = self.tables(
- '%s.local_table' % testing.db.dialect.default_schema_name,
- '%s.remote_table' % testing.config.test_schema,
- '%s.remote_table_2' % testing.config.test_schema
+ "%s.local_table" % testing.db.dialect.default_schema_name,
+ "%s.remote_table" % testing.config.test_schema,
+ "%s.remote_table_2" % testing.config.test_schema,
)
insp = inspect(config.db)
@@ -597,25 +640,25 @@ class ComponentReflectionTest(fixtures.TablesTest):
eq_(len(local_fkeys), 1)
fkey1 = local_fkeys[0]
- eq_(fkey1['referred_schema'], testing.config.test_schema)
- eq_(fkey1['referred_table'], remote_table_2.name)
- eq_(fkey1['referred_columns'], ['id', ])
- eq_(fkey1['constrained_columns'], ['remote_id'])
+ eq_(fkey1["referred_schema"], testing.config.test_schema)
+ eq_(fkey1["referred_table"], remote_table_2.name)
+ eq_(fkey1["referred_columns"], ["id"])
+ eq_(fkey1["constrained_columns"], ["remote_id"])
remote_fkeys = insp.get_foreign_keys(
- remote_table.name, schema=testing.config.test_schema)
+ remote_table.name, schema=testing.config.test_schema
+ )
eq_(len(remote_fkeys), 1)
fkey2 = remote_fkeys[0]
- assert fkey2['referred_schema'] in (
+ assert fkey2["referred_schema"] in (
None,
- testing.db.dialect.default_schema_name
+ testing.db.dialect.default_schema_name,
)
- eq_(fkey2['referred_table'], local_table.name)
- eq_(fkey2['referred_columns'], ['id', ])
- eq_(fkey2['constrained_columns'], ['local_id'])
-
+ eq_(fkey2["referred_table"], local_table.name)
+ eq_(fkey2["referred_columns"], ["id"])
+ eq_(fkey2["constrained_columns"], ["local_id"])
@testing.requires.foreign_key_constraint_option_reflection_ondelete
def test_get_foreign_key_options_ondelete(self):
@@ -630,26 +673,32 @@ class ComponentReflectionTest(fixtures.TablesTest):
meta = self.metadata
Table(
- 'x', meta,
- Column('id', Integer, primary_key=True),
- test_needs_fk=True
- )
-
- Table('table', meta,
- Column('id', Integer, primary_key=True),
- Column('x_id', Integer, sa.ForeignKey('x.id', name='xid')),
- Column('test', String(10)),
- test_needs_fk=True)
-
- Table('user', meta,
- Column('id', Integer, primary_key=True),
- Column('name', String(50), nullable=False),
- Column('tid', Integer),
- sa.ForeignKeyConstraint(
- ['tid'], ['table.id'],
- name='myfk',
- **options),
- test_needs_fk=True)
+ "x",
+ meta,
+ Column("id", Integer, primary_key=True),
+ test_needs_fk=True,
+ )
+
+ Table(
+ "table",
+ meta,
+ Column("id", Integer, primary_key=True),
+ Column("x_id", Integer, sa.ForeignKey("x.id", name="xid")),
+ Column("test", String(10)),
+ test_needs_fk=True,
+ )
+
+ Table(
+ "user",
+ meta,
+ Column("id", Integer, primary_key=True),
+ Column("name", String(50), nullable=False),
+ Column("tid", Integer),
+ sa.ForeignKeyConstraint(
+ ["tid"], ["table.id"], name="myfk", **options
+ ),
+ test_needs_fk=True,
+ )
meta.create_all()
@@ -657,49 +706,44 @@ class ComponentReflectionTest(fixtures.TablesTest):
# test 'options' is always present for a backend
# that can reflect these, since alembic looks for this
- opts = insp.get_foreign_keys('table')[0]['options']
+ opts = insp.get_foreign_keys("table")[0]["options"]
- eq_(
- dict(
- (k, opts[k])
- for k in opts if opts[k]
- ),
- {}
- )
+ eq_(dict((k, opts[k]) for k in opts if opts[k]), {})
- opts = insp.get_foreign_keys('user')[0]['options']
- eq_(
- dict(
- (k, opts[k])
- for k in opts if opts[k]
- ),
- options
- )
+ opts = insp.get_foreign_keys("user")[0]["options"]
+ eq_(dict((k, opts[k]) for k in opts if opts[k]), options)
def _assert_insp_indexes(self, indexes, expected_indexes):
- index_names = [d['name'] for d in indexes]
+ index_names = [d["name"] for d in indexes]
for e_index in expected_indexes:
- assert e_index['name'] in index_names
- index = indexes[index_names.index(e_index['name'])]
+ assert e_index["name"] in index_names
+ index = indexes[index_names.index(e_index["name"])]
for key in e_index:
eq_(e_index[key], index[key])
@testing.provide_metadata
def _test_get_indexes(self, schema=None):
meta = self.metadata
- users, addresses, dingalings = self.tables.users, \
- self.tables.email_addresses, self.tables.dingalings
+ users, addresses, dingalings = (
+ self.tables.users,
+ self.tables.email_addresses,
+ self.tables.dingalings,
+ )
# The database may decide to create indexes for foreign keys, etc.
# so there may be more indexes than expected.
insp = inspect(meta.bind)
- indexes = insp.get_indexes('users', schema=schema)
+ indexes = insp.get_indexes("users", schema=schema)
expected_indexes = [
- {'unique': False,
- 'column_names': ['test1', 'test2'],
- 'name': 'users_t_idx'},
- {'unique': False,
- 'column_names': ['user_id', 'test2', 'test1'],
- 'name': 'users_all_idx'}
+ {
+ "unique": False,
+ "column_names": ["test1", "test2"],
+ "name": "users_t_idx",
+ },
+ {
+ "unique": False,
+ "column_names": ["user_id", "test2", "test1"],
+ "name": "users_all_idx",
+ },
]
self._assert_insp_indexes(indexes, expected_indexes)
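The 'options' check above exists because Alembic autogenerate reads that key; index reflection follows the same dict-per-entry shape. Sketch (fixture-style names):

# FK options (ondelete, onupdate, ...) ride along when reflectable
opts = insp.get_foreign_keys("user")[0]["options"]

for ix in insp.get_indexes("users"):
    # backends may append implicit FK indexes beyond those declared
    print(ix["name"], ix["column_names"], ix["unique"])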
@@ -721,10 +765,7 @@ class ComponentReflectionTest(fixtures.TablesTest):
# reflecting an index that has "x DESC" in it as the column.
# the DB may or may not give us "x", but make sure we get the index
# back: it has a name and is connected to the table.
- expected_indexes = [
- {'unique': False,
- 'name': ixname}
- ]
+ expected_indexes = [{"unique": False, "name": ixname}]
self._assert_insp_indexes(indexes, expected_indexes)
t = Table(tname, meta, autoload_with=meta.bind)
@@ -748,24 +789,30 @@ class ComponentReflectionTest(fixtures.TablesTest):
@testing.requires.unique_constraint_reflection
def test_get_temp_table_unique_constraints(self):
insp = inspect(self.bind)
- reflected = insp.get_unique_constraints('user_tmp')
+ reflected = insp.get_unique_constraints("user_tmp")
for refl in reflected:
# Different dialects handle duplicate index and constraints
# differently, so ignore this flag
- refl.pop('duplicates_index', None)
- eq_(reflected, [{'column_names': ['name'], 'name': 'user_tmp_uq'}])
+ refl.pop("duplicates_index", None)
+ eq_(reflected, [{"column_names": ["name"], "name": "user_tmp_uq"}])
@testing.requires.temp_table_reflection
def test_get_temp_table_indexes(self):
insp = inspect(self.bind)
- indexes = insp.get_indexes('user_tmp')
+ indexes = insp.get_indexes("user_tmp")
for ind in indexes:
- ind.pop('dialect_options', None)
+ ind.pop("dialect_options", None)
eq_(
# TODO: we need to add better filtering for indexes/uq constraints
# that are doubled up
- [idx for idx in indexes if idx['name'] == 'user_tmp_ix'],
- [{'unique': False, 'column_names': ['foo'], 'name': 'user_tmp_ix'}]
+ [idx for idx in indexes if idx["name"] == "user_tmp_ix"],
+ [
+ {
+ "unique": False,
+ "column_names": ["foo"],
+ "name": "user_tmp_ix",
+ }
+ ],
)
@testing.requires.unique_constraint_reflection
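Unique constraints reflect similarly; some dialects implement them as unique indexes and flag that via duplicates_index, which is why the tests pop it. Sketch:

for uq in insp.get_unique_constraints("testtbl"):
    print(uq["name"], uq["column_names"], uq.get("duplicates_index"))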
@@ -783,36 +830,37 @@ class ComponentReflectionTest(fixtures.TablesTest):
# CREATE TABLE?
uniques = sorted(
[
- {'name': 'unique_a', 'column_names': ['a']},
- {'name': 'unique_a_b_c', 'column_names': ['a', 'b', 'c']},
- {'name': 'unique_c_a_b', 'column_names': ['c', 'a', 'b']},
- {'name': 'unique_asc_key', 'column_names': ['asc', 'key']},
- {'name': 'i.have.dots', 'column_names': ['b']},
- {'name': 'i have spaces', 'column_names': ['c']},
+ {"name": "unique_a", "column_names": ["a"]},
+ {"name": "unique_a_b_c", "column_names": ["a", "b", "c"]},
+ {"name": "unique_c_a_b", "column_names": ["c", "a", "b"]},
+ {"name": "unique_asc_key", "column_names": ["asc", "key"]},
+ {"name": "i.have.dots", "column_names": ["b"]},
+ {"name": "i have spaces", "column_names": ["c"]},
],
- key=operator.itemgetter('name')
+ key=operator.itemgetter("name"),
)
orig_meta = self.metadata
table = Table(
- 'testtbl', orig_meta,
- Column('a', sa.String(20)),
- Column('b', sa.String(30)),
- Column('c', sa.Integer),
+ "testtbl",
+ orig_meta,
+ Column("a", sa.String(20)),
+ Column("b", sa.String(30)),
+ Column("c", sa.Integer),
# reserved identifiers
- Column('asc', sa.String(30)),
- Column('key', sa.String(30)),
- schema=schema
+ Column("asc", sa.String(30)),
+ Column("key", sa.String(30)),
+ schema=schema,
)
for uc in uniques:
table.append_constraint(
- sa.UniqueConstraint(*uc['column_names'], name=uc['name'])
+ sa.UniqueConstraint(*uc["column_names"], name=uc["name"])
)
orig_meta.create_all()
inspector = inspect(orig_meta.bind)
reflected = sorted(
- inspector.get_unique_constraints('testtbl', schema=schema),
- key=operator.itemgetter('name')
+ inspector.get_unique_constraints("testtbl", schema=schema),
+ key=operator.itemgetter("name"),
)
names_that_duplicate_index = set()
@@ -820,25 +868,31 @@ class ComponentReflectionTest(fixtures.TablesTest):
for orig, refl in zip(uniques, reflected):
# Different dialects handle duplicate index and constraints
# differently, so ignore this flag
- dupe = refl.pop('duplicates_index', None)
+ dupe = refl.pop("duplicates_index", None)
if dupe:
names_that_duplicate_index.add(dupe)
eq_(orig, refl)
reflected_metadata = MetaData()
reflected = Table(
- 'testtbl', reflected_metadata, autoload_with=orig_meta.bind,
- schema=schema)
+ "testtbl",
+ reflected_metadata,
+ autoload_with=orig_meta.bind,
+ schema=schema,
+ )
# test "deduplicates for index" logic. MySQL and Oracle
# "unique constraints" are actually unique indexes (with possible
# exception of a unique that is a dupe of another one in the case
# of Oracle). Make sure they aren't duplicated.
idx_names = set([idx.name for idx in reflected.indexes])
- uq_names = set([
- uq.name for uq in reflected.constraints
- if isinstance(uq, sa.UniqueConstraint)]).difference(
- ['unique_c_a_b'])
+ uq_names = set(
+ [
+ uq.name
+ for uq in reflected.constraints
+ if isinstance(uq, sa.UniqueConstraint)
+ ]
+ ).difference(["unique_c_a_b"])
assert not idx_names.intersection(uq_names)
if names_that_duplicate_index:
@@ -858,47 +912,52 @@ class ComponentReflectionTest(fixtures.TablesTest):
def _test_get_check_constraints(self, schema=None):
orig_meta = self.metadata
Table(
- 'sa_cc', orig_meta,
- Column('a', Integer()),
- sa.CheckConstraint('a > 1 AND a < 5', name='cc1'),
- sa.CheckConstraint('a = 1 OR (a > 2 AND a < 5)', name='cc2'),
- schema=schema
+ "sa_cc",
+ orig_meta,
+ Column("a", Integer()),
+ sa.CheckConstraint("a > 1 AND a < 5", name="cc1"),
+ sa.CheckConstraint("a = 1 OR (a > 2 AND a < 5)", name="cc2"),
+ schema=schema,
)
orig_meta.create_all()
inspector = inspect(orig_meta.bind)
reflected = sorted(
- inspector.get_check_constraints('sa_cc', schema=schema),
- key=operator.itemgetter('name')
+ inspector.get_check_constraints("sa_cc", schema=schema),
+ key=operator.itemgetter("name"),
)
# trying to minimize the effect of quoting, parentheses, etc.
# may need to add more to this as new dialects get CHECK
# constraint reflection support
def normalize(sqltext):
- return " ".join(re.findall(r"and|\d|=|a|or|<|>", sqltext.lower(), re.I))
+ return " ".join(
+ re.findall(r"and|\d|=|a|or|<|>", sqltext.lower(), re.I)
+ )
reflected = [
- {"name": item["name"],
- "sqltext": normalize(item["sqltext"])}
+ {"name": item["name"], "sqltext": normalize(item["sqltext"])}
for item in reflected
]
eq_(
reflected,
[
- {'name': 'cc1', 'sqltext': 'a > 1 and a < 5'},
- {'name': 'cc2', 'sqltext': 'a = 1 or a > 2 and a < 5'}
- ]
+ {"name": "cc1", "sqltext": "a > 1 and a < 5"},
+ {"name": "cc2", "sqltext": "a = 1 or a > 2 and a < 5"},
+ ],
)
@testing.provide_metadata
def _test_get_view_definition(self, schema=None):
meta = self.metadata
- users, addresses, dingalings = self.tables.users, \
- self.tables.email_addresses, self.tables.dingalings
- view_name1 = 'users_v'
- view_name2 = 'email_addresses_v'
+ users, addresses, dingalings = (
+ self.tables.users,
+ self.tables.email_addresses,
+ self.tables.dingalings,
+ )
+ view_name1 = "users_v"
+ view_name2 = "email_addresses_v"
insp = inspect(meta.bind)
v1 = insp.get_view_definition(view_name1, schema=schema)
self.assert_(v1)
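CHECK constraints and view definitions round out the Inspector surface exercised here; the SQL text comes back in backend-specific quoting, hence the normalize() helper above. Sketch:

for cc in insp.get_check_constraints("sa_cc"):
    print(cc["name"], cc["sqltext"])   # text as the backend stores it

insp.get_view_definition("users_v")    # CREATE VIEW body, dialect permitting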
@@ -918,18 +977,21 @@ class ComponentReflectionTest(fixtures.TablesTest):
@testing.provide_metadata
def _test_get_table_oid(self, table_name, schema=None):
meta = self.metadata
- users, addresses, dingalings = self.tables.users, \
- self.tables.email_addresses, self.tables.dingalings
+ users, addresses, dingalings = (
+ self.tables.users,
+ self.tables.email_addresses,
+ self.tables.dingalings,
+ )
insp = inspect(meta.bind)
oid = insp.get_table_oid(table_name, schema)
self.assert_(isinstance(oid, int))
def test_get_table_oid(self):
- self._test_get_table_oid('users')
+ self._test_get_table_oid("users")
@testing.requires.schemas
def test_get_table_oid_with_schema(self):
- self._test_get_table_oid('users', schema=testing.config.test_schema)
+ self._test_get_table_oid("users", schema=testing.config.test_schema)
@testing.requires.table_reflection
@testing.provide_metadata
@@ -950,49 +1012,53 @@ class ComponentReflectionTest(fixtures.TablesTest):
insp = inspect(meta.bind)
for tname, cname in [
- ('users', 'user_id'),
- ('email_addresses', 'address_id'),
- ('dingalings', 'dingaling_id'),
+ ("users", "user_id"),
+ ("email_addresses", "address_id"),
+ ("dingalings", "dingaling_id"),
]:
cols = insp.get_columns(tname)
- id_ = {c['name']: c for c in cols}[cname]
- assert id_.get('autoincrement', True)
+ id_ = {c["name"]: c for c in cols}[cname]
+ assert id_.get("autoincrement", True)
class NormalizedNameTest(fixtures.TablesTest):
- __requires__ = 'denormalized_names',
+ __requires__ = ("denormalized_names",)
__backend__ = True
@classmethod
def define_tables(cls, metadata):
Table(
- quoted_name('t1', quote=True), metadata,
- Column('id', Integer, primary_key=True),
+ quoted_name("t1", quote=True),
+ metadata,
+ Column("id", Integer, primary_key=True),
)
Table(
- quoted_name('t2', quote=True), metadata,
- Column('id', Integer, primary_key=True),
- Column('t1id', ForeignKey('t1.id'))
+ quoted_name("t2", quote=True),
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("t1id", ForeignKey("t1.id")),
)
def test_reflect_lowercase_forced_tables(self):
m2 = MetaData(testing.db)
- t2_ref = Table(quoted_name('t2', quote=True), m2, autoload=True)
- t1_ref = m2.tables['t1']
+ t2_ref = Table(quoted_name("t2", quote=True), m2, autoload=True)
+ t1_ref = m2.tables["t1"]
assert t2_ref.c.t1id.references(t1_ref.c.id)
m3 = MetaData(testing.db)
- m3.reflect(only=lambda name, m: name.lower() in ('t1', 't2'))
- assert m3.tables['t2'].c.t1id.references(m3.tables['t1'].c.id)
+ m3.reflect(only=lambda name, m: name.lower() in ("t1", "t2"))
+ assert m3.tables["t2"].c.t1id.references(m3.tables["t1"].c.id)
def test_get_table_names(self):
tablenames = [
- t for t in inspect(testing.db).get_table_names()
- if t.lower() in ("t1", "t2")]
+ t
+ for t in inspect(testing.db).get_table_names()
+ if t.lower() in ("t1", "t2")
+ ]
eq_(tablenames[0].upper(), tablenames[0].lower())
eq_(tablenames[1].upper(), tablenames[1].lower())
-__all__ = ('ComponentReflectionTest', 'HasTableTest', 'NormalizedNameTest')
+__all__ = ("ComponentReflectionTest", "HasTableTest", "NormalizedNameTest")
diff --git a/lib/sqlalchemy/testing/suite/test_results.py b/lib/sqlalchemy/testing/suite/test_results.py
index f464d47eb..247f05cf5 100644
--- a/lib/sqlalchemy/testing/suite/test_results.py
+++ b/lib/sqlalchemy/testing/suite/test_results.py
@@ -15,14 +15,18 @@ class RowFetchTest(fixtures.TablesTest):
@classmethod
def define_tables(cls, metadata):
- Table('plain_pk', metadata,
- Column('id', Integer, primary_key=True),
- Column('data', String(50))
- )
- Table('has_dates', metadata,
- Column('id', Integer, primary_key=True),
- Column('today', DateTime)
- )
+ Table(
+ "plain_pk",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("data", String(50)),
+ )
+ Table(
+ "has_dates",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("today", DateTime),
+ )
@classmethod
def insert_data(cls):
@@ -32,65 +36,51 @@ class RowFetchTest(fixtures.TablesTest):
{"id": 1, "data": "d1"},
{"id": 2, "data": "d2"},
{"id": 3, "data": "d3"},
- ]
+ ],
)
config.db.execute(
cls.tables.has_dates.insert(),
- [
- {"id": 1, "today": datetime.datetime(2006, 5, 12, 12, 0, 0)}
- ]
+ [{"id": 1, "today": datetime.datetime(2006, 5, 12, 12, 0, 0)}],
)
def test_via_string(self):
row = config.db.execute(
- self.tables.plain_pk.select().
- order_by(self.tables.plain_pk.c.id)
+ self.tables.plain_pk.select().order_by(self.tables.plain_pk.c.id)
).first()
- eq_(
- row['id'], 1
- )
- eq_(
- row['data'], "d1"
- )
+ eq_(row["id"], 1)
+ eq_(row["data"], "d1")
def test_via_int(self):
row = config.db.execute(
- self.tables.plain_pk.select().
- order_by(self.tables.plain_pk.c.id)
+ self.tables.plain_pk.select().order_by(self.tables.plain_pk.c.id)
).first()
- eq_(
- row[0], 1
- )
- eq_(
- row[1], "d1"
- )
+ eq_(row[0], 1)
+ eq_(row[1], "d1")
def test_via_col_object(self):
row = config.db.execute(
- self.tables.plain_pk.select().
- order_by(self.tables.plain_pk.c.id)
+ self.tables.plain_pk.select().order_by(self.tables.plain_pk.c.id)
).first()
- eq_(
- row[self.tables.plain_pk.c.id], 1
- )
- eq_(
- row[self.tables.plain_pk.c.data], "d1"
- )
+ eq_(row[self.tables.plain_pk.c.id], 1)
+ eq_(row[self.tables.plain_pk.c.data], "d1")
@requirements.duplicate_names_in_cursor_description
def test_row_with_dupe_names(self):
result = config.db.execute(
- select([self.tables.plain_pk.c.data,
- self.tables.plain_pk.c.data.label('data')]).
- order_by(self.tables.plain_pk.c.id)
+ select(
+ [
+ self.tables.plain_pk.c.data,
+ self.tables.plain_pk.c.data.label("data"),
+ ]
+ ).order_by(self.tables.plain_pk.c.id)
)
row = result.first()
- eq_(result.keys(), ['data', 'data'])
- eq_(row, ('d1', 'd1'))
+ eq_(result.keys(), ["data", "data"])
+ eq_(row, ("d1", "d1"))
def test_row_w_scalar_select(self):
"""test that a scalar select as a column is returned as such
@@ -101,11 +91,11 @@ class RowFetchTest(fixtures.TablesTest):
"""
datetable = self.tables.has_dates
- s = select([datetable.alias('x').c.today]).as_scalar()
- s2 = select([datetable.c.id, s.label('somelabel')])
+ s = select([datetable.alias("x").c.today]).as_scalar()
+ s2 = select([datetable.c.id, s.label("somelabel")])
row = config.db.execute(s2).first()
- eq_(row['somelabel'], datetime.datetime(2006, 5, 12, 12, 0, 0))
+ eq_(row["somelabel"], datetime.datetime(2006, 5, 12, 12, 0, 0))
class PercentSchemaNamesTest(fixtures.TablesTest):
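RowFetchTest above covers the three RowProxy access styles; sketched against a hypothetical plain_pk table mirroring the fixture, with an engine assumed:

row = engine.execute(
    plain_pk.select().order_by(plain_pk.c.id)
).first()
row["data"]            # by string key
row[1]                 # by position
row[plain_pk.c.data]   # by Column object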
@@ -117,29 +107,31 @@ class PercentSchemaNamesTest(fixtures.TablesTest):
"""
- __requires__ = ('percent_schema_names', )
+ __requires__ = ("percent_schema_names",)
__backend__ = True
@classmethod
def define_tables(cls, metadata):
- cls.tables.percent_table = Table('percent%table', metadata,
- Column("percent%", Integer),
- Column(
- "spaces % more spaces", Integer),
- )
+ cls.tables.percent_table = Table(
+ "percent%table",
+ metadata,
+ Column("percent%", Integer),
+ Column("spaces % more spaces", Integer),
+ )
cls.tables.lightweight_percent_table = sql.table(
- 'percent%table', sql.column("percent%"),
- sql.column("spaces % more spaces")
+ "percent%table",
+ sql.column("percent%"),
+ sql.column("spaces % more spaces"),
)
def test_single_roundtrip(self):
percent_table = self.tables.percent_table
for params in [
- {'percent%': 5, 'spaces % more spaces': 12},
- {'percent%': 7, 'spaces % more spaces': 11},
- {'percent%': 9, 'spaces % more spaces': 10},
- {'percent%': 11, 'spaces % more spaces': 9}
+ {"percent%": 5, "spaces % more spaces": 12},
+ {"percent%": 7, "spaces % more spaces": 11},
+ {"percent%": 9, "spaces % more spaces": 10},
+ {"percent%": 11, "spaces % more spaces": 9},
]:
config.db.execute(percent_table.insert(), params)
self._assert_table()
@@ -147,14 +139,15 @@ class PercentSchemaNamesTest(fixtures.TablesTest):
def test_executemany_roundtrip(self):
percent_table = self.tables.percent_table
config.db.execute(
- percent_table.insert(),
- {'percent%': 5, 'spaces % more spaces': 12}
+ percent_table.insert(), {"percent%": 5, "spaces % more spaces": 12}
)
config.db.execute(
percent_table.insert(),
- [{'percent%': 7, 'spaces % more spaces': 11},
- {'percent%': 9, 'spaces % more spaces': 10},
- {'percent%': 11, 'spaces % more spaces': 9}]
+ [
+ {"percent%": 7, "spaces % more spaces": 11},
+ {"percent%": 9, "spaces % more spaces": 10},
+ {"percent%": 11, "spaces % more spaces": 9},
+ ],
)
self._assert_table()
@@ -163,85 +156,81 @@ class PercentSchemaNamesTest(fixtures.TablesTest):
lightweight_percent_table = self.tables.lightweight_percent_table
for table in (
- percent_table,
- percent_table.alias(),
- lightweight_percent_table,
- lightweight_percent_table.alias()):
+ percent_table,
+ percent_table.alias(),
+ lightweight_percent_table,
+ lightweight_percent_table.alias(),
+ ):
eq_(
list(
config.db.execute(
- table.select().order_by(table.c['percent%'])
+ table.select().order_by(table.c["percent%"])
)
),
- [
- (5, 12),
- (7, 11),
- (9, 10),
- (11, 9)
- ]
+ [(5, 12), (7, 11), (9, 10), (11, 9)],
)
eq_(
list(
config.db.execute(
- table.select().
- where(table.c['spaces % more spaces'].in_([9, 10])).
- order_by(table.c['percent%']),
+ table.select()
+ .where(table.c["spaces % more spaces"].in_([9, 10]))
+ .order_by(table.c["percent%"])
)
),
- [
- (9, 10),
- (11, 9)
- ]
+ [(9, 10), (11, 9)],
)
- row = config.db.execute(table.select().
- order_by(table.c['percent%'])).first()
- eq_(row['percent%'], 5)
- eq_(row['spaces % more spaces'], 12)
+ row = config.db.execute(
+ table.select().order_by(table.c["percent%"])
+ ).first()
+ eq_(row["percent%"], 5)
+ eq_(row["spaces % more spaces"], 12)
- eq_(row[table.c['percent%']], 5)
- eq_(row[table.c['spaces % more spaces']], 12)
+ eq_(row[table.c["percent%"]], 5)
+ eq_(row[table.c["spaces % more spaces"]], 12)
config.db.execute(
percent_table.update().values(
- {percent_table.c['spaces % more spaces']: 15}
+ {percent_table.c["spaces % more spaces"]: 15}
)
)
eq_(
list(
config.db.execute(
- percent_table.
- select().
- order_by(percent_table.c['percent%'])
+ percent_table.select().order_by(
+ percent_table.c["percent%"]
+ )
)
),
- [(5, 15), (7, 15), (9, 15), (11, 15)]
+ [(5, 15), (7, 15), (9, 15), (11, 15)],
)
-class ServerSideCursorsTest(fixtures.TestBase, testing.AssertsExecutionResults):
+class ServerSideCursorsTest(
+ fixtures.TestBase, testing.AssertsExecutionResults
+):
- __requires__ = ('server_side_cursors', )
+ __requires__ = ("server_side_cursors",)
__backend__ = True
def _is_server_side(self, cursor):
if self.engine.dialect.driver == "psycopg2":
return cursor.name
- elif self.engine.dialect.driver == 'pymysql':
- sscursor = __import__('pymysql.cursors').cursors.SSCursor
+ elif self.engine.dialect.driver == "pymysql":
+ sscursor = __import__("pymysql.cursors").cursors.SSCursor
return isinstance(cursor, sscursor)
elif self.engine.dialect.driver == "mysqldb":
- sscursor = __import__('MySQLdb.cursors').cursors.SSCursor
+ sscursor = __import__("MySQLdb.cursors").cursors.SSCursor
return isinstance(cursor, sscursor)
else:
return False
def _fixture(self, server_side_cursors):
self.engine = engines.testing_engine(
- options={'server_side_cursors': server_side_cursors}
+ options={"server_side_cursors": server_side_cursors}
)
return self.engine
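Server-side cursors can be requested per engine, per connection, or per statement, which is what the permutations below verify; a sketch (URL illustrative; the engine-wide flag is a psycopg2 dialect option):

from sqlalchemy import create_engine, text

# engine-wide, psycopg2 only:
engine = create_engine(
    "postgresql://scott:tiger@localhost/test", server_side_cursors=True
)
# per connection / per statement:
conn = engine.connect().execution_options(stream_results=True)
result = conn.execute(text("select 1"))  # named/server-side cursor here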
@@ -251,12 +240,12 @@ class ServerSideCursorsTest(fixtures.TestBase, testing.AssertsExecutionResults):
def test_global_string(self):
engine = self._fixture(True)
- result = engine.execute('select 1')
+ result = engine.execute("select 1")
assert self._is_server_side(result.cursor)
def test_global_text(self):
engine = self._fixture(True)
- result = engine.execute(text('select 1'))
+ result = engine.execute(text("select 1"))
assert self._is_server_side(result.cursor)
def test_global_expr(self):
@@ -266,7 +255,7 @@ class ServerSideCursorsTest(fixtures.TestBase, testing.AssertsExecutionResults):
def test_global_off_explicit(self):
engine = self._fixture(False)
- result = engine.execute(text('select 1'))
+ result = engine.execute(text("select 1"))
# It should be off globally ...
@@ -286,10 +275,11 @@ class ServerSideCursorsTest(fixtures.TestBase, testing.AssertsExecutionResults):
engine = self._fixture(False)
# and this one
- result = \
- engine.connect().execution_options(stream_results=True).\
- execute('select 1'
- )
+ result = (
+ engine.connect()
+ .execution_options(stream_results=True)
+ .execute("select 1")
+ )
assert self._is_server_side(result.cursor)
def test_stmt_enabled_conn_option_disabled(self):
@@ -298,9 +288,9 @@ class ServerSideCursorsTest(fixtures.TestBase, testing.AssertsExecutionResults):
s = select([1]).execution_options(stream_results=True)
# not this one
- result = \
- engine.connect().execution_options(stream_results=False).\
- execute(s)
+ result = (
+ engine.connect().execution_options(stream_results=False).execute(s)
+ )
assert not self._is_server_side(result.cursor)
def test_stmt_option_disabled(self):
@@ -329,18 +319,18 @@ class ServerSideCursorsTest(fixtures.TestBase, testing.AssertsExecutionResults):
def test_for_update_string(self):
engine = self._fixture(True)
- result = engine.execute('SELECT 1 FOR UPDATE')
+ result = engine.execute("SELECT 1 FOR UPDATE")
assert self._is_server_side(result.cursor)
def test_text_no_ss(self):
engine = self._fixture(False)
- s = text('select 42')
+ s = text("select 42")
result = engine.execute(s)
assert not self._is_server_side(result.cursor)
def test_text_ss_option(self):
engine = self._fixture(False)
- s = text('select 42').execution_options(stream_results=True)
+ s = text("select 42").execution_options(stream_results=True)
result = engine.execute(s)
assert self._is_server_side(result.cursor)
@@ -349,19 +339,25 @@ class ServerSideCursorsTest(fixtures.TestBase, testing.AssertsExecutionResults):
md = self.metadata
engine = self._fixture(True)
- test_table = Table('test_table', md,
- Column('id', Integer, primary_key=True),
- Column('data', String(50)))
+ test_table = Table(
+ "test_table",
+ md,
+ Column("id", Integer, primary_key=True),
+ Column("data", String(50)),
+ )
test_table.create(checkfirst=True)
- test_table.insert().execute(data='data1')
- test_table.insert().execute(data='data2')
- eq_(test_table.select().order_by(test_table.c.id).execute().fetchall(),
- [(1, 'data1'), (2, 'data2')])
- test_table.update().where(
- test_table.c.id == 2).values(
- data=test_table.c.data +
- ' updated').execute()
- eq_(test_table.select().order_by(test_table.c.id).execute().fetchall(),
- [(1, 'data1'), (2, 'data2 updated')])
+ test_table.insert().execute(data="data1")
+ test_table.insert().execute(data="data2")
+ eq_(
+ test_table.select().order_by(test_table.c.id).execute().fetchall(),
+ [(1, "data1"), (2, "data2")],
+ )
+ test_table.update().where(test_table.c.id == 2).values(
+ data=test_table.c.data + " updated"
+ ).execute()
+ eq_(
+ test_table.select().order_by(test_table.c.id).execute().fetchall(),
+ [(1, "data1"), (2, "data2 updated")],
+ )
test_table.delete().execute()
- eq_(select([func.count('*')]).select_from(test_table).scalar(), 0)
+ eq_(select([func.count("*")]).select_from(test_table).scalar(), 0)
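End of test_results.py. The final hunk uses the old implicit-execution style (statement.execute()); for comparison, the closing count assertion in explicit-execution form would read roughly:

from sqlalchemy import func, select

# test_table as in the fixture above; engine assumed
engine.execute(
    select([func.count("*")]).select_from(test_table)
).scalar()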
diff --git a/lib/sqlalchemy/testing/suite/test_select.py b/lib/sqlalchemy/testing/suite/test_select.py
index 73ce02492..032b68eb6 100644
--- a/lib/sqlalchemy/testing/suite/test_select.py
+++ b/lib/sqlalchemy/testing/suite/test_select.py
@@ -16,10 +16,12 @@ class CollateTest(fixtures.TablesTest):
@classmethod
def define_tables(cls, metadata):
- Table("some_table", metadata,
- Column('id', Integer, primary_key=True),
- Column('data', String(100))
- )
+ Table(
+ "some_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("data", String(100)),
+ )
@classmethod
def insert_data(cls):
@@ -28,26 +30,21 @@ class CollateTest(fixtures.TablesTest):
[
{"id": 1, "data": "collate data1"},
{"id": 2, "data": "collate data2"},
- ]
+ ],
)
def _assert_result(self, select, result):
- eq_(
- config.db.execute(select).fetchall(),
- result
- )
+ eq_(config.db.execute(select).fetchall(), result)
@testing.requires.order_by_collation
def test_collate_order_by(self):
collation = testing.requires.get_order_by_collation(testing.config)
self._assert_result(
- select([self.tables.some_table]).
- order_by(self.tables.some_table.c.data.collate(collation).asc()),
- [
- (1, "collate data1"),
- (2, "collate data2"),
- ]
+ select([self.tables.some_table]).order_by(
+ self.tables.some_table.c.data.collate(collation).asc()
+ ),
+ [(1, "collate data1"), (2, "collate data2")],
)
@@ -59,17 +56,20 @@ class OrderByLabelTest(fixtures.TablesTest):
setting.
"""
+
__backend__ = True
@classmethod
def define_tables(cls, metadata):
- Table("some_table", metadata,
- Column('id', Integer, primary_key=True),
- Column('x', Integer),
- Column('y', Integer),
- Column('q', String(50)),
- Column('p', String(50))
- )
+ Table(
+ "some_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("x", Integer),
+ Column("y", Integer),
+ Column("q", String(50)),
+ Column("p", String(50)),
+ )
@classmethod
def insert_data(cls):
@@ -79,65 +79,55 @@ class OrderByLabelTest(fixtures.TablesTest):
{"id": 1, "x": 1, "y": 2, "q": "q1", "p": "p3"},
{"id": 2, "x": 2, "y": 3, "q": "q2", "p": "p2"},
{"id": 3, "x": 3, "y": 4, "q": "q3", "p": "p1"},
- ]
+ ],
)
def _assert_result(self, select, result):
- eq_(
- config.db.execute(select).fetchall(),
- result
- )
+ eq_(config.db.execute(select).fetchall(), result)
def test_plain(self):
table = self.tables.some_table
- lx = table.c.x.label('lx')
- self._assert_result(
- select([lx]).order_by(lx),
- [(1, ), (2, ), (3, )]
- )
+ lx = table.c.x.label("lx")
+ self._assert_result(select([lx]).order_by(lx), [(1,), (2,), (3,)])
def test_composed_int(self):
table = self.tables.some_table
- lx = (table.c.x + table.c.y).label('lx')
- self._assert_result(
- select([lx]).order_by(lx),
- [(3, ), (5, ), (7, )]
- )
+ lx = (table.c.x + table.c.y).label("lx")
+ self._assert_result(select([lx]).order_by(lx), [(3,), (5,), (7,)])
def test_composed_multiple(self):
table = self.tables.some_table
- lx = (table.c.x + table.c.y).label('lx')
- ly = (func.lower(table.c.q) + table.c.p).label('ly')
+ lx = (table.c.x + table.c.y).label("lx")
+ ly = (func.lower(table.c.q) + table.c.p).label("ly")
self._assert_result(
select([lx, ly]).order_by(lx, ly.desc()),
- [(3, util.u('q1p3')), (5, util.u('q2p2')), (7, util.u('q3p1'))]
+ [(3, util.u("q1p3")), (5, util.u("q2p2")), (7, util.u("q3p1"))],
)
def test_plain_desc(self):
table = self.tables.some_table
- lx = table.c.x.label('lx')
+ lx = table.c.x.label("lx")
self._assert_result(
- select([lx]).order_by(lx.desc()),
- [(3, ), (2, ), (1, )]
+ select([lx]).order_by(lx.desc()), [(3,), (2,), (1,)]
)
def test_composed_int_desc(self):
table = self.tables.some_table
- lx = (table.c.x + table.c.y).label('lx')
+ lx = (table.c.x + table.c.y).label("lx")
self._assert_result(
- select([lx]).order_by(lx.desc()),
- [(7, ), (5, ), (3, )]
+ select([lx]).order_by(lx.desc()), [(7,), (5,), (3,)]
)
@testing.requires.group_by_complex_expression
def test_group_by_composed(self):
table = self.tables.some_table
- expr = (table.c.x + table.c.y).label('lx')
- stmt = select([func.count(table.c.id), expr]).group_by(expr).order_by(expr)
- self._assert_result(
- stmt,
- [(1, 3), (1, 5), (1, 7)]
+ expr = (table.c.x + table.c.y).label("lx")
+ stmt = (
+ select([func.count(table.c.id), expr])
+ .group_by(expr)
+ .order_by(expr)
)
+ self._assert_result(stmt, [(1, 3), (1, 5), (1, 7)])
class LimitOffsetTest(fixtures.TablesTest):
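OrderByLabelTest checks that backends accept a label in ORDER BY / GROUP BY rather than forcing the full expression to repeat; sketch (tbl is a hypothetical table with x, y and id columns):

from sqlalchemy import func, select

lx = (tbl.c.x + tbl.c.y).label("lx")
stmt = (
    select([func.count(tbl.c.id), lx])
    .group_by(lx)   # renders GROUP BY lx where the backend allows it
    .order_by(lx)
)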
@@ -145,10 +135,13 @@ class LimitOffsetTest(fixtures.TablesTest):
@classmethod
def define_tables(cls, metadata):
- Table("some_table", metadata,
- Column('id', Integer, primary_key=True),
- Column('x', Integer),
- Column('y', Integer))
+ Table(
+ "some_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("x", Integer),
+ Column("y", Integer),
+ )
@classmethod
def insert_data(cls):
@@ -159,20 +152,17 @@ class LimitOffsetTest(fixtures.TablesTest):
{"id": 2, "x": 2, "y": 3},
{"id": 3, "x": 3, "y": 4},
{"id": 4, "x": 4, "y": 5},
- ]
+ ],
)
def _assert_result(self, select, result, params=()):
- eq_(
- config.db.execute(select, params).fetchall(),
- result
- )
+ eq_(config.db.execute(select, params).fetchall(), result)
def test_simple_limit(self):
table = self.tables.some_table
self._assert_result(
select([table]).order_by(table.c.id).limit(2),
- [(1, 1, 2), (2, 2, 3)]
+ [(1, 1, 2), (2, 2, 3)],
)
@testing.requires.offset
@@ -180,7 +170,7 @@ class LimitOffsetTest(fixtures.TablesTest):
table = self.tables.some_table
self._assert_result(
select([table]).order_by(table.c.id).offset(2),
- [(3, 3, 4), (4, 4, 5)]
+ [(3, 3, 4), (4, 4, 5)],
)
@testing.requires.offset
@@ -188,7 +178,7 @@ class LimitOffsetTest(fixtures.TablesTest):
table = self.tables.some_table
self._assert_result(
select([table]).order_by(table.c.id).limit(2).offset(1),
- [(2, 2, 3), (3, 3, 4)]
+ [(2, 2, 3), (3, 3, 4)],
)
@testing.requires.offset
@@ -198,41 +188,40 @@ class LimitOffsetTest(fixtures.TablesTest):
table = self.tables.some_table
stmt = select([table]).order_by(table.c.id).limit(2).offset(1)
sql = stmt.compile(
- dialect=config.db.dialect,
- compile_kwargs={"literal_binds": True})
+ dialect=config.db.dialect, compile_kwargs={"literal_binds": True}
+ )
sql = str(sql)
- self._assert_result(
- sql,
- [(2, 2, 3), (3, 3, 4)]
- )
+ self._assert_result(sql, [(2, 2, 3), (3, 3, 4)])
@testing.requires.bound_limit_offset
def test_bound_limit(self):
table = self.tables.some_table
self._assert_result(
- select([table]).order_by(table.c.id).limit(bindparam('l')),
+ select([table]).order_by(table.c.id).limit(bindparam("l")),
[(1, 1, 2), (2, 2, 3)],
- params={"l": 2}
+ params={"l": 2},
)
@testing.requires.bound_limit_offset
def test_bound_offset(self):
table = self.tables.some_table
self._assert_result(
- select([table]).order_by(table.c.id).offset(bindparam('o')),
+ select([table]).order_by(table.c.id).offset(bindparam("o")),
[(3, 3, 4), (4, 4, 5)],
- params={"o": 2}
+ params={"o": 2},
)
@testing.requires.bound_limit_offset
def test_bound_limit_offset(self):
table = self.tables.some_table
self._assert_result(
- select([table]).order_by(table.c.id).
- limit(bindparam("l")).offset(bindparam("o")),
+ select([table])
+ .order_by(table.c.id)
+ .limit(bindparam("l"))
+ .offset(bindparam("o")),
[(2, 2, 3), (3, 3, 4)],
- params={"l": 2, "o": 1}
+ params={"l": 2, "o": 1},
)
@@ -241,10 +230,13 @@ class CompoundSelectTest(fixtures.TablesTest):
@classmethod
def define_tables(cls, metadata):
- Table("some_table", metadata,
- Column('id', Integer, primary_key=True),
- Column('x', Integer),
- Column('y', Integer))
+ Table(
+ "some_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("x", Integer),
+ Column("y", Integer),
+ )
@classmethod
def insert_data(cls):
@@ -255,14 +247,11 @@ class CompoundSelectTest(fixtures.TablesTest):
{"id": 2, "x": 2, "y": 3},
{"id": 3, "x": 3, "y": 4},
{"id": 4, "x": 4, "y": 5},
- ]
+ ],
)
def _assert_result(self, select, result, params=()):
- eq_(
- config.db.execute(select, params).fetchall(),
- result
- )
+ eq_(config.db.execute(select, params).fetchall(), result)
def test_plain_union(self):
table = self.tables.some_table
@@ -270,10 +259,7 @@ class CompoundSelectTest(fixtures.TablesTest):
s2 = select([table]).where(table.c.id == 3)
u1 = union(s1, s2)
- self._assert_result(
- u1.order_by(u1.c.id),
- [(2, 2, 3), (3, 3, 4)]
- )
+ self._assert_result(u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)])
def test_select_from_plain_union(self):
table = self.tables.some_table
@@ -281,80 +267,88 @@ class CompoundSelectTest(fixtures.TablesTest):
s2 = select([table]).where(table.c.id == 3)
u1 = union(s1, s2).alias().select()
- self._assert_result(
- u1.order_by(u1.c.id),
- [(2, 2, 3), (3, 3, 4)]
- )
+ self._assert_result(u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)])
@testing.requires.order_by_col_from_union
@testing.requires.parens_in_union_contained_select_w_limit_offset
def test_limit_offset_selectable_in_unions(self):
table = self.tables.some_table
- s1 = select([table]).where(table.c.id == 2).\
- limit(1).order_by(table.c.id)
- s2 = select([table]).where(table.c.id == 3).\
- limit(1).order_by(table.c.id)
+ s1 = (
+ select([table])
+ .where(table.c.id == 2)
+ .limit(1)
+ .order_by(table.c.id)
+ )
+ s2 = (
+ select([table])
+ .where(table.c.id == 3)
+ .limit(1)
+ .order_by(table.c.id)
+ )
u1 = union(s1, s2).limit(2)
- self._assert_result(
- u1.order_by(u1.c.id),
- [(2, 2, 3), (3, 3, 4)]
- )
+ self._assert_result(u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)])
@testing.requires.parens_in_union_contained_select_wo_limit_offset
def test_order_by_selectable_in_unions(self):
table = self.tables.some_table
- s1 = select([table]).where(table.c.id == 2).\
- order_by(table.c.id)
- s2 = select([table]).where(table.c.id == 3).\
- order_by(table.c.id)
+ s1 = select([table]).where(table.c.id == 2).order_by(table.c.id)
+ s2 = select([table]).where(table.c.id == 3).order_by(table.c.id)
u1 = union(s1, s2).limit(2)
- self._assert_result(
- u1.order_by(u1.c.id),
- [(2, 2, 3), (3, 3, 4)]
- )
+ self._assert_result(u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)])
def test_distinct_selectable_in_unions(self):
table = self.tables.some_table
- s1 = select([table]).where(table.c.id == 2).\
- distinct()
- s2 = select([table]).where(table.c.id == 3).\
- distinct()
+ s1 = select([table]).where(table.c.id == 2).distinct()
+ s2 = select([table]).where(table.c.id == 3).distinct()
u1 = union(s1, s2).limit(2)
- self._assert_result(
- u1.order_by(u1.c.id),
- [(2, 2, 3), (3, 3, 4)]
- )
+ self._assert_result(u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)])
@testing.requires.parens_in_union_contained_select_w_limit_offset
def test_limit_offset_in_unions_from_alias(self):
table = self.tables.some_table
- s1 = select([table]).where(table.c.id == 2).\
- limit(1).order_by(table.c.id)
- s2 = select([table]).where(table.c.id == 3).\
- limit(1).order_by(table.c.id)
+ s1 = (
+ select([table])
+ .where(table.c.id == 2)
+ .limit(1)
+ .order_by(table.c.id)
+ )
+ s2 = (
+ select([table])
+ .where(table.c.id == 3)
+ .limit(1)
+ .order_by(table.c.id)
+ )
# this necessarily has double parens
u1 = union(s1, s2).alias()
self._assert_result(
- u1.select().limit(2).order_by(u1.c.id),
- [(2, 2, 3), (3, 3, 4)]
+ u1.select().limit(2).order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]
)
def test_limit_offset_aliased_selectable_in_unions(self):
table = self.tables.some_table
- s1 = select([table]).where(table.c.id == 2).\
- limit(1).order_by(table.c.id).alias().select()
- s2 = select([table]).where(table.c.id == 3).\
- limit(1).order_by(table.c.id).alias().select()
+ s1 = (
+ select([table])
+ .where(table.c.id == 2)
+ .limit(1)
+ .order_by(table.c.id)
+ .alias()
+ .select()
+ )
+ s2 = (
+ select([table])
+ .where(table.c.id == 3)
+ .limit(1)
+ .order_by(table.c.id)
+ .alias()
+ .select()
+ )
u1 = union(s1, s2).limit(2)
- self._assert_result(
- u1.order_by(u1.c.id),
- [(2, 2, 3), (3, 3, 4)]
- )
+ self._assert_result(u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)])
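These union tests revolve around one dialect quirk: a SELECT that carries its own ORDER BY or LIMIT usually must be parenthesized, or wrapped as a subquery, before it can participate in a compound statement. A rough sketch of the subquery-wrapping shape, using the same illustrative table setup as above:

    from sqlalchemy import (
        Column, Integer, MetaData, Table, create_engine, select, union,
    )

    engine = create_engine("sqlite://")  # illustrative engine
    metadata = MetaData()
    t = Table("some_table", metadata, Column("id", Integer, primary_key=True))
    metadata.create_all(engine)
    engine.execute(t.insert(), [{"id": i} for i in (1, 2, 3)])

    s1 = select([t]).where(t.c.id == 2).limit(1).order_by(t.c.id)
    s2 = select([t]).where(t.c.id == 3).limit(1).order_by(t.c.id)

    # wrapping each member as an aliased subquery keeps the inner
    # ORDER BY / LIMIT legal on backends that reject bare parenthesized
    # SELECTs inside a compound statement
    u1 = union(s1.alias().select(), s2.alias().select())
    print(engine.execute(u1.order_by(u1.c.id)).fetchall())  # [(2,), (3,)]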
class ExpandingBoundInTest(fixtures.TablesTest):
@@ -362,11 +356,14 @@ class ExpandingBoundInTest(fixtures.TablesTest):
@classmethod
def define_tables(cls, metadata):
- Table("some_table", metadata,
- Column('id', Integer, primary_key=True),
- Column('x', Integer),
- Column('y', Integer),
- Column('z', String(50)))
+ Table(
+ "some_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("x", Integer),
+ Column("y", Integer),
+ Column("z", String(50)),
+ )
@classmethod
def insert_data(cls):
@@ -377,178 +374,184 @@ class ExpandingBoundInTest(fixtures.TablesTest):
{"id": 2, "x": 2, "y": 3, "z": "z2"},
{"id": 3, "x": 3, "y": 4, "z": "z3"},
{"id": 4, "x": 4, "y": 5, "z": "z4"},
- ]
+ ],
)
def _assert_result(self, select, result, params=()):
- eq_(
- config.db.execute(select, params).fetchall(),
- result
- )
+ eq_(config.db.execute(select, params).fetchall(), result)
def test_multiple_empty_sets(self):
# test that any anonymous aliasing used by the dialect
# is fine with duplicates
table = self.tables.some_table
- stmt = select([table.c.id]).where(
- table.c.x.in_(bindparam('q', expanding=True))).where(
- table.c.y.in_(bindparam('p', expanding=True))
- ).order_by(table.c.id)
-
- self._assert_result(
- stmt,
- [],
- params={"q": [], "p": []},
+ stmt = (
+ select([table.c.id])
+ .where(table.c.x.in_(bindparam("q", expanding=True)))
+ .where(table.c.y.in_(bindparam("p", expanding=True)))
+ .order_by(table.c.id)
)
+ self._assert_result(stmt, [], params={"q": [], "p": []})
+
@testing.requires.tuple_in
def test_empty_heterogeneous_tuples(self):
table = self.tables.some_table
- stmt = select([table.c.id]).where(
- tuple_(table.c.x, table.c.z).in_(
- bindparam('q', expanding=True))).order_by(table.c.id)
-
- self._assert_result(
- stmt,
- [],
- params={"q": []},
+ stmt = (
+ select([table.c.id])
+ .where(
+ tuple_(table.c.x, table.c.z).in_(
+ bindparam("q", expanding=True)
+ )
+ )
+ .order_by(table.c.id)
)
+ self._assert_result(stmt, [], params={"q": []})
+
@testing.requires.tuple_in
def test_empty_homogeneous_tuples(self):
table = self.tables.some_table
- stmt = select([table.c.id]).where(
- tuple_(table.c.x, table.c.y).in_(
- bindparam('q', expanding=True))).order_by(table.c.id)
-
- self._assert_result(
- stmt,
- [],
- params={"q": []},
+ stmt = (
+ select([table.c.id])
+ .where(
+ tuple_(table.c.x, table.c.y).in_(
+ bindparam("q", expanding=True)
+ )
+ )
+ .order_by(table.c.id)
)
+ self._assert_result(stmt, [], params={"q": []})
+
def test_bound_in_scalar(self):
table = self.tables.some_table
- stmt = select([table.c.id]).where(
- table.c.x.in_(bindparam('q', expanding=True))).order_by(table.c.id)
-
- self._assert_result(
- stmt,
- [(2, ), (3, ), (4, )],
- params={"q": [2, 3, 4]},
+ stmt = (
+ select([table.c.id])
+ .where(table.c.x.in_(bindparam("q", expanding=True)))
+ .order_by(table.c.id)
)
+ self._assert_result(stmt, [(2,), (3,), (4,)], params={"q": [2, 3, 4]})
+
@testing.requires.tuple_in
def test_bound_in_two_tuple(self):
table = self.tables.some_table
- stmt = select([table.c.id]).where(
- tuple_(table.c.x, table.c.y).in_(
- bindparam('q', expanding=True))).order_by(table.c.id)
+ stmt = (
+ select([table.c.id])
+ .where(
+ tuple_(table.c.x, table.c.y).in_(
+ bindparam("q", expanding=True)
+ )
+ )
+ .order_by(table.c.id)
+ )
self._assert_result(
- stmt,
- [(2, ), (3, ), (4, )],
- params={"q": [(2, 3), (3, 4), (4, 5)]},
+ stmt, [(2,), (3,), (4,)], params={"q": [(2, 3), (3, 4), (4, 5)]}
)
@testing.requires.tuple_in
def test_bound_in_heterogeneous_two_tuple(self):
table = self.tables.some_table
- stmt = select([table.c.id]).where(
- tuple_(table.c.x, table.c.z).in_(
- bindparam('q', expanding=True))).order_by(table.c.id)
+ stmt = (
+ select([table.c.id])
+ .where(
+ tuple_(table.c.x, table.c.z).in_(
+ bindparam("q", expanding=True)
+ )
+ )
+ .order_by(table.c.id)
+ )
self._assert_result(
stmt,
- [(2, ), (3, ), (4, )],
+ [(2,), (3,), (4,)],
params={"q": [(2, "z2"), (3, "z3"), (4, "z4")]},
)
def test_empty_set_against_integer(self):
table = self.tables.some_table
- stmt = select([table.c.id]).where(
- table.c.x.in_(bindparam('q', expanding=True))).order_by(table.c.id)
-
- self._assert_result(
- stmt,
- [],
- params={"q": []},
+ stmt = (
+ select([table.c.id])
+ .where(table.c.x.in_(bindparam("q", expanding=True)))
+ .order_by(table.c.id)
)
+ self._assert_result(stmt, [], params={"q": []})
+
def test_empty_set_against_integer_negation(self):
table = self.tables.some_table
- stmt = select([table.c.id]).where(
- table.c.x.notin_(bindparam('q', expanding=True))
- ).order_by(table.c.id)
-
- self._assert_result(
- stmt,
- [(1, ), (2, ), (3, ), (4, )],
- params={"q": []},
+ stmt = (
+ select([table.c.id])
+ .where(table.c.x.notin_(bindparam("q", expanding=True)))
+ .order_by(table.c.id)
)
+ self._assert_result(stmt, [(1,), (2,), (3,), (4,)], params={"q": []})
+
def test_empty_set_against_string(self):
table = self.tables.some_table
- stmt = select([table.c.id]).where(
- table.c.z.in_(bindparam('q', expanding=True))).order_by(table.c.id)
-
- self._assert_result(
- stmt,
- [],
- params={"q": []},
+ stmt = (
+ select([table.c.id])
+ .where(table.c.z.in_(bindparam("q", expanding=True)))
+ .order_by(table.c.id)
)
+ self._assert_result(stmt, [], params={"q": []})
+
def test_empty_set_against_string_negation(self):
table = self.tables.some_table
- stmt = select([table.c.id]).where(
- table.c.z.notin_(bindparam('q', expanding=True))
- ).order_by(table.c.id)
-
- self._assert_result(
- stmt,
- [(1, ), (2, ), (3, ), (4, )],
- params={"q": []},
+ stmt = (
+ select([table.c.id])
+ .where(table.c.z.notin_(bindparam("q", expanding=True)))
+ .order_by(table.c.id)
)
+ self._assert_result(stmt, [(1,), (2,), (3,), (4,)], params={"q": []})
+
def test_null_in_empty_set_is_false(self):
- stmt = select([
- case(
- [
- (
- null().in_(bindparam('foo', value=(), expanding=True)),
- true()
- )
- ],
- else_=false()
- )
- ])
- in_(
- config.db.execute(stmt).fetchone()[0],
- (False, 0)
+ stmt = select(
+ [
+ case(
+ [
+ (
+ null().in_(
+ bindparam("foo", value=(), expanding=True)
+ ),
+ true(),
+ )
+ ],
+ else_=false(),
+ )
+ ]
)
+ in_(config.db.execute(stmt).fetchone()[0], (False, 0))
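ExpandingBoundInTest covers expanding bound parameters: a single bindparam stands in for a whole IN list and is expanded to the right number of placeholders at execution time, and an empty list renders as a statically false expression (true, for NOT IN) instead of invalid SQL. A condensed sketch with illustrative names:

    from sqlalchemy import (
        Column, Integer, MetaData, Table, bindparam, create_engine, select,
    )

    engine = create_engine("sqlite://")  # illustrative engine
    metadata = MetaData()
    t = Table("some_table", metadata, Column("x", Integer))
    metadata.create_all(engine)
    engine.execute(t.insert(), [{"x": i} for i in (1, 2, 3, 4)])

    stmt = (
        select([t.c.x])
        .where(t.c.x.in_(bindparam("q", expanding=True)))
        .order_by(t.c.x)
    )
    # the list is supplied per execution and expanded into IN (?, ?, ?)
    print(engine.execute(stmt, {"q": [2, 3, 4]}).fetchall())  # [(2,), (3,), (4,)]
    print(engine.execute(stmt, {"q": []}).fetchall())         # [] -- no SQL error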
class LikeFunctionsTest(fixtures.TablesTest):
__backend__ = True
- run_inserts = 'once'
+ run_inserts = "once"
run_deletes = None
@classmethod
def define_tables(cls, metadata):
- Table("some_table", metadata,
- Column('id', Integer, primary_key=True),
- Column('data', String(50)))
+ Table(
+ "some_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("data", String(50)),
+ )
@classmethod
def insert_data(cls):
@@ -565,7 +568,7 @@ class LikeFunctionsTest(fixtures.TablesTest):
{"id": 8, "data": "ab9cdefg"},
{"id": 9, "data": "abcde#fg"},
{"id": 10, "data": "abcd9fg"},
- ]
+ ],
)
def _test(self, expr, expected):
@@ -573,8 +576,10 @@ class LikeFunctionsTest(fixtures.TablesTest):
with config.db.connect() as conn:
rows = {
- value for value, in
- conn.execute(select([some_table.c.id]).where(expr))
+ value
+ for value, in conn.execute(
+ select([some_table.c.id]).where(expr)
+ )
}
eq_(rows, expected)
@@ -591,7 +596,8 @@ class LikeFunctionsTest(fixtures.TablesTest):
col = self.tables.some_table.c.data
self._test(
col.startswith(literal_column("'ab%c'")),
- {1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
+ {1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
+ )
def test_startswith_escape(self):
col = self.tables.some_table.c.data
@@ -608,8 +614,9 @@ class LikeFunctionsTest(fixtures.TablesTest):
def test_endswith_sqlexpr(self):
col = self.tables.some_table.c.data
- self._test(col.endswith(literal_column("'e%fg'")),
- {1, 2, 3, 4, 5, 6, 7, 8, 9})
+ self._test(
+ col.endswith(literal_column("'e%fg'")), {1, 2, 3, 4, 5, 6, 7, 8, 9}
+ )
def test_endswith_autoescape(self):
col = self.tables.some_table.c.data
@@ -640,5 +647,3 @@ class LikeFunctionsTest(fixtures.TablesTest):
col = self.tables.some_table.c.data
self._test(col.contains("b%cd", autoescape=True, escape="#"), {3})
self._test(col.contains("b#cd", autoescape=True, escape="#"), {7})
-
-
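LikeFunctionsTest is about the escaping knobs on startswith/endswith/contains: autoescape=True escapes any %, _, and escape character occurring inside the supplied value so they match literally, while escape= selects the escape character used in the emitted LIKE. A small sketch of the difference, assuming a column of sample strings:

    from sqlalchemy import Column, MetaData, String, Table, create_engine, select

    engine = create_engine("sqlite://")  # illustrative engine
    metadata = MetaData()
    t = Table("some_table", metadata, Column("data", String(50)))
    metadata.create_all(engine)
    engine.execute(t.insert(), [{"data": "ab%cdefg"}, {"data": "abcdefg"}])

    # plain startswith: % stays a wildcard, so both rows match
    print(engine.execute(select([t]).where(t.c.data.startswith("ab%c"))).fetchall())
    # autoescape: % is matched literally, so only "ab%cdefg" matches
    print(
        engine.execute(
            select([t]).where(t.c.data.startswith("ab%c", autoescape=True))
        ).fetchall()
    )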
diff --git a/lib/sqlalchemy/testing/suite/test_sequence.py b/lib/sqlalchemy/testing/suite/test_sequence.py
index f1c00de6b..15a850fe9 100644
--- a/lib/sqlalchemy/testing/suite/test_sequence.py
+++ b/lib/sqlalchemy/testing/suite/test_sequence.py
@@ -9,140 +9,144 @@ from ..schema import Table, Column
class SequenceTest(fixtures.TablesTest):
- __requires__ = ('sequences',)
+ __requires__ = ("sequences",)
__backend__ = True
- run_create_tables = 'each'
+ run_create_tables = "each"
@classmethod
def define_tables(cls, metadata):
- Table('seq_pk', metadata,
- Column('id', Integer, Sequence('tab_id_seq'), primary_key=True),
- Column('data', String(50))
- )
+ Table(
+ "seq_pk",
+ metadata,
+ Column("id", Integer, Sequence("tab_id_seq"), primary_key=True),
+ Column("data", String(50)),
+ )
- Table('seq_opt_pk', metadata,
- Column('id', Integer, Sequence('tab_id_seq', optional=True),
- primary_key=True),
- Column('data', String(50))
- )
+ Table(
+ "seq_opt_pk",
+ metadata,
+ Column(
+ "id",
+ Integer,
+ Sequence("tab_id_seq", optional=True),
+ primary_key=True,
+ ),
+ Column("data", String(50)),
+ )
def test_insert_roundtrip(self):
- config.db.execute(
- self.tables.seq_pk.insert(),
- data="some data"
- )
+ config.db.execute(self.tables.seq_pk.insert(), data="some data")
self._assert_round_trip(self.tables.seq_pk, config.db)
def test_insert_lastrowid(self):
- r = config.db.execute(
- self.tables.seq_pk.insert(),
- data="some data"
- )
- eq_(
- r.inserted_primary_key,
- [1]
- )
+ r = config.db.execute(self.tables.seq_pk.insert(), data="some data")
+ eq_(r.inserted_primary_key, [1])
def test_nextval_direct(self):
- r = config.db.execute(
- self.tables.seq_pk.c.id.default
- )
- eq_(
- r, 1
- )
+ r = config.db.execute(self.tables.seq_pk.c.id.default)
+ eq_(r, 1)
@requirements.sequences_optional
def test_optional_seq(self):
r = config.db.execute(
- self.tables.seq_opt_pk.insert(),
- data="some data"
- )
- eq_(
- r.inserted_primary_key,
- [1]
+ self.tables.seq_opt_pk.insert(), data="some data"
)
+ eq_(r.inserted_primary_key, [1])
def _assert_round_trip(self, table, conn):
row = conn.execute(table.select()).first()
- eq_(
- row,
- (1, "some data")
- )
+ eq_(row, (1, "some data"))
class SequenceCompilerTest(testing.AssertsCompiledSQL, fixtures.TestBase):
- __requires__ = ('sequences',)
+ __requires__ = ("sequences",)
__backend__ = True
def test_literal_binds_inline_compile(self):
table = Table(
- 'x', MetaData(),
- Column('y', Integer, Sequence('y_seq')),
- Column('q', Integer))
+ "x",
+ MetaData(),
+ Column("y", Integer, Sequence("y_seq")),
+ Column("q", Integer),
+ )
stmt = table.insert().values(q=5)
seq_nextval = testing.db.dialect.statement_compiler(
- statement=None, dialect=testing.db.dialect).visit_sequence(
- Sequence("y_seq"))
+ statement=None, dialect=testing.db.dialect
+ ).visit_sequence(Sequence("y_seq"))
self.assert_compile(
stmt,
- "INSERT INTO x (y, q) VALUES (%s, 5)" % (seq_nextval, ),
+ "INSERT INTO x (y, q) VALUES (%s, 5)" % (seq_nextval,),
literal_binds=True,
- dialect=testing.db.dialect)
+ dialect=testing.db.dialect,
+ )
class HasSequenceTest(fixtures.TestBase):
- __requires__ = 'sequences',
+ __requires__ = ("sequences",)
__backend__ = True
def test_has_sequence(self):
- s1 = Sequence('user_id_seq')
+ s1 = Sequence("user_id_seq")
testing.db.execute(schema.CreateSequence(s1))
try:
- eq_(testing.db.dialect.has_sequence(testing.db,
- 'user_id_seq'), True)
+ eq_(
+ testing.db.dialect.has_sequence(testing.db, "user_id_seq"),
+ True,
+ )
finally:
testing.db.execute(schema.DropSequence(s1))
@testing.requires.schemas
def test_has_sequence_schema(self):
- s1 = Sequence('user_id_seq', schema=config.test_schema)
+ s1 = Sequence("user_id_seq", schema=config.test_schema)
testing.db.execute(schema.CreateSequence(s1))
try:
- eq_(testing.db.dialect.has_sequence(
- testing.db, 'user_id_seq', schema=config.test_schema), True)
+ eq_(
+ testing.db.dialect.has_sequence(
+ testing.db, "user_id_seq", schema=config.test_schema
+ ),
+ True,
+ )
finally:
testing.db.execute(schema.DropSequence(s1))
def test_has_sequence_neg(self):
- eq_(testing.db.dialect.has_sequence(testing.db, 'user_id_seq'),
- False)
+ eq_(testing.db.dialect.has_sequence(testing.db, "user_id_seq"), False)
@testing.requires.schemas
def test_has_sequence_schemas_neg(self):
- eq_(testing.db.dialect.has_sequence(testing.db, 'user_id_seq',
- schema=config.test_schema),
- False)
+ eq_(
+ testing.db.dialect.has_sequence(
+ testing.db, "user_id_seq", schema=config.test_schema
+ ),
+ False,
+ )
@testing.requires.schemas
def test_has_sequence_default_not_in_remote(self):
- s1 = Sequence('user_id_seq')
+ s1 = Sequence("user_id_seq")
testing.db.execute(schema.CreateSequence(s1))
try:
- eq_(testing.db.dialect.has_sequence(testing.db, 'user_id_seq',
- schema=config.test_schema),
- False)
+ eq_(
+ testing.db.dialect.has_sequence(
+ testing.db, "user_id_seq", schema=config.test_schema
+ ),
+ False,
+ )
finally:
testing.db.execute(schema.DropSequence(s1))
@testing.requires.schemas
def test_has_sequence_remote_not_in_default(self):
- s1 = Sequence('user_id_seq', schema=config.test_schema)
+ s1 = Sequence("user_id_seq", schema=config.test_schema)
testing.db.execute(schema.CreateSequence(s1))
try:
- eq_(testing.db.dialect.has_sequence(testing.db, 'user_id_seq'),
- False)
+ eq_(
+ testing.db.dialect.has_sequence(testing.db, "user_id_seq"),
+ False,
+ )
finally:
testing.db.execute(schema.DropSequence(s1))
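A compressed sketch of what SequenceTest and HasSequenceTest verify, assuming a backend with real sequences (the PostgreSQL URL is a placeholder):

    from sqlalchemy import (
        Column, Integer, MetaData, Sequence, String, Table, create_engine,
    )

    engine = create_engine("postgresql://scott:tiger@localhost/test")  # placeholder URL
    metadata = MetaData()
    t = Table(
        "seq_pk",
        metadata,
        Column("id", Integer, Sequence("tab_id_seq"), primary_key=True),
        Column("data", String(50)),
    )
    metadata.create_all(engine)

    r = engine.execute(t.insert(), data="some data")
    print(r.inserted_primary_key)  # [1] on a fresh sequence
    # nextval can also be fired directly by executing the column default
    print(engine.execute(t.c.id.default))  # 2
    # dialect-level introspection for the sequence's existence
    print(engine.dialect.has_sequence(engine, "tab_id_seq"))  # True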
diff --git a/lib/sqlalchemy/testing/suite/test_types.py b/lib/sqlalchemy/testing/suite/test_types.py
index 27c7bb115..6dfb80915 100644
--- a/lib/sqlalchemy/testing/suite/test_types.py
+++ b/lib/sqlalchemy/testing/suite/test_types.py
@@ -4,9 +4,24 @@ from .. import fixtures, config
from ..assertions import eq_
from ..config import requirements
from sqlalchemy import Integer, Unicode, UnicodeText, select, TIMESTAMP
-from sqlalchemy import Date, DateTime, Time, MetaData, String, \
- Text, Numeric, Float, literal, Boolean, cast, null, JSON, and_, \
- type_coerce, BigInteger
+from sqlalchemy import (
+ Date,
+ DateTime,
+ Time,
+ MetaData,
+ String,
+ Text,
+ Numeric,
+ Float,
+ literal,
+ Boolean,
+ cast,
+ null,
+ JSON,
+ and_,
+ type_coerce,
+ BigInteger,
+)
from ..schema import Table, Column
from ... import testing
import decimal
@@ -24,13 +39,17 @@ class _LiteralRoundTripFixture(object):
# into a typed column. we can then SELECT it back as its
# official type; ideally we'd be able to use CAST here
# but MySQL in particular can't CAST fully
- t = Table('t', self.metadata, Column('x', type_))
+ t = Table("t", self.metadata, Column("x", type_))
t.create()
for value in input_:
- ins = t.insert().values(x=literal(value)).compile(
- dialect=testing.db.dialect,
- compile_kwargs=dict(literal_binds=True)
+ ins = (
+ t.insert()
+ .values(x=literal(value))
+ .compile(
+ dialect=testing.db.dialect,
+ compile_kwargs=dict(literal_binds=True),
+ )
)
testing.db.execute(ins)
@@ -42,40 +61,33 @@ class _LiteralRoundTripFixture(object):
class _UnicodeFixture(_LiteralRoundTripFixture):
- __requires__ = 'unicode_data',
+ __requires__ = ("unicode_data",)
- data = u("Alors vous imaginez ma surprise, au lever du jour, "
- "quand une drôle de petite voix m’a réveillé. Elle "
- "disait: « S’il vous plaît… dessine-moi un mouton! »")
+ data = u(
+ "Alors vous imaginez ma surprise, au lever du jour, "
+ "quand une drôle de petite voix m’a réveillé. Elle "
+ "disait: « S’il vous plaît… dessine-moi un mouton! »"
+ )
@classmethod
def define_tables(cls, metadata):
- Table('unicode_table', metadata,
- Column('id', Integer, primary_key=True,
- test_needs_autoincrement=True),
- Column('unicode_data', cls.datatype),
- )
+ Table(
+ "unicode_table",
+ metadata,
+ Column(
+ "id", Integer, primary_key=True, test_needs_autoincrement=True
+ ),
+ Column("unicode_data", cls.datatype),
+ )
def test_round_trip(self):
unicode_table = self.tables.unicode_table
- config.db.execute(
- unicode_table.insert(),
- {
- 'unicode_data': self.data,
- }
- )
+ config.db.execute(unicode_table.insert(), {"unicode_data": self.data})
- row = config.db.execute(
- select([
- unicode_table.c.unicode_data,
- ])
- ).first()
+ row = config.db.execute(select([unicode_table.c.unicode_data])).first()
- eq_(
- row,
- (self.data, )
- )
+ eq_(row, (self.data,))
assert isinstance(row[0], util.text_type)
def test_round_trip_executemany(self):
@@ -83,44 +95,29 @@ class _UnicodeFixture(_LiteralRoundTripFixture):
config.db.execute(
unicode_table.insert(),
- [
- {
- 'unicode_data': self.data,
- }
- for i in range(3)
- ]
+ [{"unicode_data": self.data} for i in range(3)],
)
rows = config.db.execute(
- select([
- unicode_table.c.unicode_data,
- ])
+ select([unicode_table.c.unicode_data])
).fetchall()
- eq_(
- rows,
- [(self.data, ) for i in range(3)]
- )
+ eq_(rows, [(self.data,) for i in range(3)])
for row in rows:
assert isinstance(row[0], util.text_type)
def _test_empty_strings(self):
unicode_table = self.tables.unicode_table
- config.db.execute(
- unicode_table.insert(),
- {"unicode_data": u('')}
- )
- row = config.db.execute(
- select([unicode_table.c.unicode_data])
- ).first()
- eq_(row, (u(''),))
+ config.db.execute(unicode_table.insert(), {"unicode_data": u("")})
+ row = config.db.execute(select([unicode_table.c.unicode_data])).first()
+ eq_(row, (u(""),))
def test_literal(self):
self._literal_round_trip(self.datatype, [self.data], [self.data])
class UnicodeVarcharTest(_UnicodeFixture, fixtures.TablesTest):
- __requires__ = 'unicode_data',
+ __requires__ = ("unicode_data",)
__backend__ = True
datatype = Unicode(255)
@@ -131,7 +128,7 @@ class UnicodeVarcharTest(_UnicodeFixture, fixtures.TablesTest):
class UnicodeTextTest(_UnicodeFixture, fixtures.TablesTest):
- __requires__ = 'unicode_data', 'text_type'
+ __requires__ = "unicode_data", "text_type"
__backend__ = True
datatype = UnicodeText()
@@ -142,54 +139,47 @@ class UnicodeTextTest(_UnicodeFixture, fixtures.TablesTest):
class TextTest(_LiteralRoundTripFixture, fixtures.TablesTest):
- __requires__ = 'text_type',
+ __requires__ = ("text_type",)
__backend__ = True
@classmethod
def define_tables(cls, metadata):
- Table('text_table', metadata,
- Column('id', Integer, primary_key=True,
- test_needs_autoincrement=True),
- Column('text_data', Text),
- )
+ Table(
+ "text_table",
+ metadata,
+ Column(
+ "id", Integer, primary_key=True, test_needs_autoincrement=True
+ ),
+ Column("text_data", Text),
+ )
def test_text_roundtrip(self):
text_table = self.tables.text_table
- config.db.execute(
- text_table.insert(),
- {"text_data": 'some text'}
- )
- row = config.db.execute(
- select([text_table.c.text_data])
- ).first()
- eq_(row, ('some text',))
+ config.db.execute(text_table.insert(), {"text_data": "some text"})
+ row = config.db.execute(select([text_table.c.text_data])).first()
+ eq_(row, ("some text",))
def test_text_empty_strings(self):
text_table = self.tables.text_table
- config.db.execute(
- text_table.insert(),
- {"text_data": ''}
- )
- row = config.db.execute(
- select([text_table.c.text_data])
- ).first()
- eq_(row, ('',))
+ config.db.execute(text_table.insert(), {"text_data": ""})
+ row = config.db.execute(select([text_table.c.text_data])).first()
+ eq_(row, ("",))
def test_literal(self):
self._literal_round_trip(Text, ["some text"], ["some text"])
def test_literal_quoting(self):
- data = '''some 'text' hey "hi there" that's text'''
+ data = """some 'text' hey "hi there" that's text"""
self._literal_round_trip(Text, [data], [data])
def test_literal_backslashes(self):
- data = r'backslash one \ backslash two \\ end'
+ data = r"backslash one \ backslash two \\ end"
self._literal_round_trip(Text, [data], [data])
def test_literal_percentsigns(self):
- data = r'percent % signs %% percent'
+ data = r"percent % signs %% percent"
self._literal_round_trip(Text, [data], [data])
@@ -199,9 +189,7 @@ class StringTest(_LiteralRoundTripFixture, fixtures.TestBase):
@requirements.unbounded_varchar
def test_nolength_string(self):
metadata = MetaData()
- foo = Table('foo', metadata,
- Column('one', String)
- )
+ foo = Table("foo", metadata, Column("one", String))
foo.create(config.db)
foo.drop(config.db)
@@ -210,11 +198,11 @@ class StringTest(_LiteralRoundTripFixture, fixtures.TestBase):
self._literal_round_trip(String(40), ["some text"], ["some text"])
def test_literal_quoting(self):
- data = '''some 'text' hey "hi there" that's text'''
+ data = """some 'text' hey "hi there" that's text"""
self._literal_round_trip(String(40), [data], [data])
def test_literal_backslashes(self):
- data = r'backslash one \ backslash two \\ end'
+ data = r"backslash one \ backslash two \\ end"
self._literal_round_trip(String(40), [data], [data])
@@ -223,44 +211,32 @@ class _DateFixture(_LiteralRoundTripFixture):
@classmethod
def define_tables(cls, metadata):
- Table('date_table', metadata,
- Column('id', Integer, primary_key=True,
- test_needs_autoincrement=True),
- Column('date_data', cls.datatype),
- )
+ Table(
+ "date_table",
+ metadata,
+ Column(
+ "id", Integer, primary_key=True, test_needs_autoincrement=True
+ ),
+ Column("date_data", cls.datatype),
+ )
def test_round_trip(self):
date_table = self.tables.date_table
- config.db.execute(
- date_table.insert(),
- {'date_data': self.data}
- )
+ config.db.execute(date_table.insert(), {"date_data": self.data})
- row = config.db.execute(
- select([
- date_table.c.date_data,
- ])
- ).first()
+ row = config.db.execute(select([date_table.c.date_data])).first()
compare = self.compare or self.data
- eq_(row,
- (compare, ))
+ eq_(row, (compare,))
assert isinstance(row[0], type(compare))
def test_null(self):
date_table = self.tables.date_table
- config.db.execute(
- date_table.insert(),
- {'date_data': None}
- )
+ config.db.execute(date_table.insert(), {"date_data": None})
- row = config.db.execute(
- select([
- date_table.c.date_data,
- ])
- ).first()
+ row = config.db.execute(select([date_table.c.date_data])).first()
eq_(row, (None,))
@testing.requires.datetime_literals
@@ -270,48 +246,49 @@ class _DateFixture(_LiteralRoundTripFixture):
class DateTimeTest(_DateFixture, fixtures.TablesTest):
- __requires__ = 'datetime',
+ __requires__ = ("datetime",)
__backend__ = True
datatype = DateTime
data = datetime.datetime(2012, 10, 15, 12, 57, 18)
class DateTimeMicrosecondsTest(_DateFixture, fixtures.TablesTest):
- __requires__ = 'datetime_microseconds',
+ __requires__ = ("datetime_microseconds",)
__backend__ = True
datatype = DateTime
data = datetime.datetime(2012, 10, 15, 12, 57, 18, 396)
+
class TimestampMicrosecondsTest(_DateFixture, fixtures.TablesTest):
- __requires__ = 'timestamp_microseconds',
+ __requires__ = ("timestamp_microseconds",)
__backend__ = True
datatype = TIMESTAMP
data = datetime.datetime(2012, 10, 15, 12, 57, 18, 396)
class TimeTest(_DateFixture, fixtures.TablesTest):
- __requires__ = 'time',
+ __requires__ = ("time",)
__backend__ = True
datatype = Time
data = datetime.time(12, 57, 18)
class TimeMicrosecondsTest(_DateFixture, fixtures.TablesTest):
- __requires__ = 'time_microseconds',
+ __requires__ = ("time_microseconds",)
__backend__ = True
datatype = Time
data = datetime.time(12, 57, 18, 396)
class DateTest(_DateFixture, fixtures.TablesTest):
- __requires__ = 'date',
+ __requires__ = ("date",)
__backend__ = True
datatype = Date
data = datetime.date(2012, 10, 15)
class DateTimeCoercedToDateTimeTest(_DateFixture, fixtures.TablesTest):
- __requires__ = 'date', 'date_coerces_from_datetime'
+ __requires__ = "date", "date_coerces_from_datetime"
__backend__ = True
datatype = Date
data = datetime.datetime(2012, 10, 15, 12, 57, 18)
@@ -319,14 +296,14 @@ class DateTimeCoercedToDateTimeTest(_DateFixture, fixtures.TablesTest):
class DateTimeHistoricTest(_DateFixture, fixtures.TablesTest):
- __requires__ = 'datetime_historic',
+ __requires__ = ("datetime_historic",)
__backend__ = True
datatype = DateTime
data = datetime.datetime(1850, 11, 10, 11, 52, 35)
class DateHistoricTest(_DateFixture, fixtures.TablesTest):
- __requires__ = 'date_historic',
+ __requires__ = ("date_historic",)
__backend__ = True
datatype = Date
data = datetime.date(1727, 4, 1)
@@ -345,26 +322,21 @@ class IntegerTest(_LiteralRoundTripFixture, fixtures.TestBase):
def _round_trip(self, datatype, data):
metadata = self.metadata
int_table = Table(
- 'integer_table', metadata,
- Column('id', Integer, primary_key=True,
- test_needs_autoincrement=True),
- Column('integer_data', datatype),
+ "integer_table",
+ metadata,
+ Column(
+ "id", Integer, primary_key=True, test_needs_autoincrement=True
+ ),
+ Column("integer_data", datatype),
)
metadata.create_all(config.db)
- config.db.execute(
- int_table.insert(),
- {'integer_data': data}
- )
+ config.db.execute(int_table.insert(), {"integer_data": data})
- row = config.db.execute(
- select([
- int_table.c.integer_data,
- ])
- ).first()
+ row = config.db.execute(select([int_table.c.integer_data])).first()
- eq_(row, (data, ))
+ eq_(row, (data,))
if util.py3k:
assert isinstance(row[0], int)
@@ -377,12 +349,11 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase):
@testing.emits_warning(r".*does \*not\* support Decimal objects natively")
@testing.provide_metadata
- def _do_test(self, type_, input_, output,
- filter_=None, check_scale=False):
+ def _do_test(self, type_, input_, output, filter_=None, check_scale=False):
metadata = self.metadata
- t = Table('t', metadata, Column('x', type_))
+ t = Table("t", metadata, Column("x", type_))
t.create()
- t.insert().execute([{'x': x} for x in input_])
+ t.insert().execute([{"x": x} for x in input_])
result = {row[0] for row in t.select().execute()}
output = set(output)
@@ -391,10 +362,7 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase):
output = set(filter_(x) for x in output)
eq_(result, output)
if check_scale:
- eq_(
- [str(x) for x in result],
- [str(x) for x in output],
- )
+ eq_([str(x) for x in result], [str(x) for x in output])
@testing.emits_warning(r".*does \*not\* support Decimal objects natively")
def test_render_literal_numeric(self):
@@ -416,8 +384,8 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase):
self._literal_round_trip(
Float(4),
[15.7563, decimal.Decimal("15.7563")],
- [15.7563, ],
- filter_=lambda n: n is not None and round(n, 5) or None
+ [15.7563],
+ filter_=lambda n: n is not None and round(n, 5) or None,
)
@testing.requires.precision_generic_float_type
@@ -425,8 +393,8 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase):
self._do_test(
Float(None, decimal_return_scale=7, asdecimal=True),
[15.7563827, decimal.Decimal("15.7563827")],
- [decimal.Decimal("15.7563827"), ],
- check_scale=True
+ [decimal.Decimal("15.7563827")],
+ check_scale=True,
)
def test_numeric_as_decimal(self):
@@ -445,18 +413,12 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase):
@testing.requires.fetch_null_from_numeric
def test_numeric_null_as_decimal(self):
- self._do_test(
- Numeric(precision=8, scale=4),
- [None],
- [None],
- )
+ self._do_test(Numeric(precision=8, scale=4), [None], [None])
@testing.requires.fetch_null_from_numeric
def test_numeric_null_as_float(self):
self._do_test(
- Numeric(precision=8, scale=4, asdecimal=False),
- [None],
- [None],
+ Numeric(precision=8, scale=4, asdecimal=False), [None], [None]
)
@testing.requires.floats_to_four_decimals
@@ -472,15 +434,13 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase):
Float(precision=8),
[15.7563, decimal.Decimal("15.7563")],
[15.7563],
- filter_=lambda n: n is not None and round(n, 5) or None
+ filter_=lambda n: n is not None and round(n, 5) or None,
)
def test_float_coerce_round_trip(self):
expr = 15.7563
- val = testing.db.scalar(
- select([literal(expr)])
- )
+ val = testing.db.scalar(select([literal(expr)]))
eq_(val, expr)
# this does not work in MySQL, see #4036, however we choose not
@@ -491,34 +451,28 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase):
def test_decimal_coerce_round_trip(self):
expr = decimal.Decimal("15.7563")
- val = testing.db.scalar(
- select([literal(expr)])
- )
+ val = testing.db.scalar(select([literal(expr)]))
eq_(val, expr)
@testing.emits_warning(r".*does \*not\* support Decimal objects natively")
def test_decimal_coerce_round_trip_w_cast(self):
expr = decimal.Decimal("15.7563")
- val = testing.db.scalar(
- select([cast(expr, Numeric(10, 4))])
- )
+ val = testing.db.scalar(select([cast(expr, Numeric(10, 4))]))
eq_(val, expr)
@testing.requires.precision_numerics_general
def test_precision_decimal(self):
- numbers = set([
- decimal.Decimal("54.234246451650"),
- decimal.Decimal("0.004354"),
- decimal.Decimal("900.0"),
- ])
-
- self._do_test(
- Numeric(precision=18, scale=12),
- numbers,
- numbers,
+ numbers = set(
+ [
+ decimal.Decimal("54.234246451650"),
+ decimal.Decimal("0.004354"),
+ decimal.Decimal("900.0"),
+ ]
)
+ self._do_test(Numeric(precision=18, scale=12), numbers, numbers)
+
@testing.requires.precision_numerics_enotation_large
def test_enotation_decimal(self):
"""test exceedingly small decimals.
@@ -528,25 +482,23 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase):
"""
- numbers = set([
- decimal.Decimal('1E-2'),
- decimal.Decimal('1E-3'),
- decimal.Decimal('1E-4'),
- decimal.Decimal('1E-5'),
- decimal.Decimal('1E-6'),
- decimal.Decimal('1E-7'),
- decimal.Decimal('1E-8'),
- decimal.Decimal("0.01000005940696"),
- decimal.Decimal("0.00000005940696"),
- decimal.Decimal("0.00000000000696"),
- decimal.Decimal("0.70000000000696"),
- decimal.Decimal("696E-12"),
- ])
- self._do_test(
- Numeric(precision=18, scale=14),
- numbers,
- numbers
+ numbers = set(
+ [
+ decimal.Decimal("1E-2"),
+ decimal.Decimal("1E-3"),
+ decimal.Decimal("1E-4"),
+ decimal.Decimal("1E-5"),
+ decimal.Decimal("1E-6"),
+ decimal.Decimal("1E-7"),
+ decimal.Decimal("1E-8"),
+ decimal.Decimal("0.01000005940696"),
+ decimal.Decimal("0.00000005940696"),
+ decimal.Decimal("0.00000000000696"),
+ decimal.Decimal("0.70000000000696"),
+ decimal.Decimal("696E-12"),
+ ]
)
+ self._do_test(Numeric(precision=18, scale=14), numbers, numbers)
@testing.requires.precision_numerics_enotation_large
def test_enotation_decimal_large(self):
@@ -554,41 +506,32 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase):
"""
- numbers = set([
- decimal.Decimal('4E+8'),
- decimal.Decimal("5748E+15"),
- decimal.Decimal('1.521E+15'),
- decimal.Decimal('00000000000000.1E+12'),
- ])
- self._do_test(
- Numeric(precision=25, scale=2),
- numbers,
- numbers
+ numbers = set(
+ [
+ decimal.Decimal("4E+8"),
+ decimal.Decimal("5748E+15"),
+ decimal.Decimal("1.521E+15"),
+ decimal.Decimal("00000000000000.1E+12"),
+ ]
)
+ self._do_test(Numeric(precision=25, scale=2), numbers, numbers)
@testing.requires.precision_numerics_many_significant_digits
def test_many_significant_digits(self):
- numbers = set([
- decimal.Decimal("31943874831932418390.01"),
- decimal.Decimal("319438950232418390.273596"),
- decimal.Decimal("87673.594069654243"),
- ])
- self._do_test(
- Numeric(precision=38, scale=12),
- numbers,
- numbers
+ numbers = set(
+ [
+ decimal.Decimal("31943874831932418390.01"),
+ decimal.Decimal("319438950232418390.273596"),
+ decimal.Decimal("87673.594069654243"),
+ ]
)
+ self._do_test(Numeric(precision=38, scale=12), numbers, numbers)
@testing.requires.precision_numerics_retains_significant_digits
def test_numeric_no_decimal(self):
- numbers = set([
- decimal.Decimal("1.000")
- ])
+ numbers = set([decimal.Decimal("1.000")])
self._do_test(
- Numeric(precision=5, scale=3),
- numbers,
- numbers,
- check_scale=True
+ Numeric(precision=5, scale=3), numbers, numbers, check_scale=True
)
@@ -597,42 +540,32 @@ class BooleanTest(_LiteralRoundTripFixture, fixtures.TablesTest):
@classmethod
def define_tables(cls, metadata):
- Table('boolean_table', metadata,
- Column('id', Integer, primary_key=True, autoincrement=False),
- Column('value', Boolean),
- Column('unconstrained_value', Boolean(create_constraint=False)),
- )
+ Table(
+ "boolean_table",
+ metadata,
+ Column("id", Integer, primary_key=True, autoincrement=False),
+ Column("value", Boolean),
+ Column("unconstrained_value", Boolean(create_constraint=False)),
+ )
def test_render_literal_bool(self):
- self._literal_round_trip(
- Boolean(),
- [True, False],
- [True, False]
- )
+ self._literal_round_trip(Boolean(), [True, False], [True, False])
def test_round_trip(self):
boolean_table = self.tables.boolean_table
config.db.execute(
boolean_table.insert(),
- {
- 'id': 1,
- 'value': True,
- 'unconstrained_value': False
- }
+ {"id": 1, "value": True, "unconstrained_value": False},
)
row = config.db.execute(
- select([
- boolean_table.c.value,
- boolean_table.c.unconstrained_value
- ])
+ select(
+ [boolean_table.c.value, boolean_table.c.unconstrained_value]
+ )
).first()
- eq_(
- row,
- (True, False)
- )
+ eq_(row, (True, False))
assert isinstance(row[0], bool)
def test_null(self):
@@ -640,24 +573,16 @@ class BooleanTest(_LiteralRoundTripFixture, fixtures.TablesTest):
config.db.execute(
boolean_table.insert(),
- {
- 'id': 1,
- 'value': None,
- 'unconstrained_value': None
- }
+ {"id": 1, "value": None, "unconstrained_value": None},
)
row = config.db.execute(
- select([
- boolean_table.c.value,
- boolean_table.c.unconstrained_value
- ])
+ select(
+ [boolean_table.c.value, boolean_table.c.unconstrained_value]
+ )
).first()
- eq_(
- row,
- (None, None)
- )
+ eq_(row, (None, None))
def test_whereclause(self):
# testing "WHERE <column>" renders a compatible expression
@@ -667,92 +592,82 @@ class BooleanTest(_LiteralRoundTripFixture, fixtures.TablesTest):
conn.execute(
boolean_table.insert(),
[
- {'id': 1, 'value': True, 'unconstrained_value': True},
- {'id': 2, 'value': False, 'unconstrained_value': False}
- ]
+ {"id": 1, "value": True, "unconstrained_value": True},
+ {"id": 2, "value": False, "unconstrained_value": False},
+ ],
)
eq_(
conn.scalar(
select([boolean_table.c.id]).where(boolean_table.c.value)
),
- 1
+ 1,
)
eq_(
conn.scalar(
select([boolean_table.c.id]).where(
- boolean_table.c.unconstrained_value)
+ boolean_table.c.unconstrained_value
+ )
),
- 1
+ 1,
)
eq_(
conn.scalar(
select([boolean_table.c.id]).where(~boolean_table.c.value)
),
- 2
+ 2,
)
eq_(
conn.scalar(
select([boolean_table.c.id]).where(
- ~boolean_table.c.unconstrained_value)
+ ~boolean_table.c.unconstrained_value
+ )
),
- 2
+ 2,
)
-
-
class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest):
- __requires__ = 'json_type',
+ __requires__ = ("json_type",)
__backend__ = True
datatype = JSON
- data1 = {
- "key1": "value1",
- "key2": "value2"
- }
+ data1 = {"key1": "value1", "key2": "value2"}
data2 = {
"Key 'One'": "value1",
"key two": "value2",
- "key three": "value ' three '"
+ "key three": "value ' three '",
}
data3 = {
"key1": [1, 2, 3],
"key2": ["one", "two", "three"],
- "key3": [{"four": "five"}, {"six": "seven"}]
+ "key3": [{"four": "five"}, {"six": "seven"}],
}
data4 = ["one", "two", "three"]
data5 = {
"nested": {
- "elem1": [
- {"a": "b", "c": "d"},
- {"e": "f", "g": "h"}
- ],
- "elem2": {
- "elem3": {"elem4": "elem5"}
- }
+ "elem1": [{"a": "b", "c": "d"}, {"e": "f", "g": "h"}],
+ "elem2": {"elem3": {"elem4": "elem5"}},
}
}
- data6 = {
- "a": 5,
- "b": "some value",
- "c": {"foo": "bar"}
- }
+ data6 = {"a": 5, "b": "some value", "c": {"foo": "bar"}}
@classmethod
def define_tables(cls, metadata):
- Table('data_table', metadata,
- Column('id', Integer, primary_key=True),
- Column('name', String(30), nullable=False),
- Column('data', cls.datatype),
- Column('nulldata', cls.datatype(none_as_null=True))
- )
+ Table(
+ "data_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("name", String(30), nullable=False),
+ Column("data", cls.datatype),
+ Column("nulldata", cls.datatype(none_as_null=True)),
+ )
def test_round_trip_data1(self):
self._test_round_trip(self.data1)
@@ -761,99 +676,82 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest):
data_table = self.tables.data_table
config.db.execute(
- data_table.insert(),
- {'name': 'row1', 'data': data_element}
+ data_table.insert(), {"name": "row1", "data": data_element}
)
- row = config.db.execute(
- select([
- data_table.c.data,
- ])
- ).first()
+ row = config.db.execute(select([data_table.c.data])).first()
- eq_(row, (data_element, ))
+ eq_(row, (data_element,))
def test_round_trip_none_as_sql_null(self):
- col = self.tables.data_table.c['nulldata']
+ col = self.tables.data_table.c["nulldata"]
with config.db.connect() as conn:
conn.execute(
- self.tables.data_table.insert(),
- {"name": "r1", "data": None}
+ self.tables.data_table.insert(), {"name": "r1", "data": None}
)
eq_(
conn.scalar(
- select([self.tables.data_table.c.name]).
- where(col.is_(null()))
+ select([self.tables.data_table.c.name]).where(
+ col.is_(null())
+ )
),
- "r1"
+ "r1",
)
- eq_(
- conn.scalar(
- select([col])
- ),
- None
- )
+ eq_(conn.scalar(select([col])), None)
def test_round_trip_json_null_as_json_null(self):
- col = self.tables.data_table.c['data']
+ col = self.tables.data_table.c["data"]
with config.db.connect() as conn:
conn.execute(
self.tables.data_table.insert(),
- {"name": "r1", "data": JSON.NULL}
+ {"name": "r1", "data": JSON.NULL},
)
eq_(
conn.scalar(
- select([self.tables.data_table.c.name]).
- where(cast(col, String) == 'null')
+ select([self.tables.data_table.c.name]).where(
+ cast(col, String) == "null"
+ )
),
- "r1"
+ "r1",
)
- eq_(
- conn.scalar(
- select([col])
- ),
- None
- )
+ eq_(conn.scalar(select([col])), None)
def test_round_trip_none_as_json_null(self):
- col = self.tables.data_table.c['data']
+ col = self.tables.data_table.c["data"]
with config.db.connect() as conn:
conn.execute(
- self.tables.data_table.insert(),
- {"name": "r1", "data": None}
+ self.tables.data_table.insert(), {"name": "r1", "data": None}
)
eq_(
conn.scalar(
- select([self.tables.data_table.c.name]).
- where(cast(col, String) == 'null')
+ select([self.tables.data_table.c.name]).where(
+ cast(col, String) == "null"
+ )
),
- "r1"
+ "r1",
)
- eq_(
- conn.scalar(
- select([col])
- ),
- None
- )
+ eq_(conn.scalar(select([col])), None)
def _criteria_fixture(self):
config.db.execute(
self.tables.data_table.insert(),
- [{"name": "r1", "data": self.data1},
- {"name": "r2", "data": self.data2},
- {"name": "r3", "data": self.data3},
- {"name": "r4", "data": self.data4},
- {"name": "r5", "data": self.data5},
- {"name": "r6", "data": self.data6}]
+ [
+ {"name": "r1", "data": self.data1},
+ {"name": "r2", "data": self.data2},
+ {"name": "r3", "data": self.data3},
+ {"name": "r4", "data": self.data4},
+ {"name": "r5", "data": self.data5},
+ {"name": "r6", "data": self.data6},
+ ],
)
def _test_index_criteria(self, crit, expected, test_literal=True):
@@ -861,20 +759,20 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest):
with config.db.connect() as conn:
stmt = select([self.tables.data_table.c.name]).where(crit)
- eq_(
- conn.scalar(stmt),
- expected
- )
+ eq_(conn.scalar(stmt), expected)
if test_literal:
- literal_sql = str(stmt.compile(
- config.db, compile_kwargs={"literal_binds": True}))
+ literal_sql = str(
+ stmt.compile(
+ config.db, compile_kwargs={"literal_binds": True}
+ )
+ )
eq_(conn.scalar(literal_sql), expected)
def test_crit_spaces_in_key(self):
name = self.tables.data_table.c.name
- col = self.tables.data_table.c['data']
+ col = self.tables.data_table.c["data"]
# limit the rows here to avoid PG error
# "cannot extract field from a non-object", which is
@@ -882,76 +780,74 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest):
self._test_index_criteria(
and_(
name.in_(["r1", "r2", "r3"]),
- cast(col["key two"], String) == '"value2"'
+ cast(col["key two"], String) == '"value2"',
),
- "r2"
+ "r2",
)
@config.requirements.json_array_indexes
def test_crit_simple_int(self):
name = self.tables.data_table.c.name
- col = self.tables.data_table.c['data']
+ col = self.tables.data_table.c["data"]
# limit the rows here to avoid PG error
# "cannot extract array element from a non-array", which is
# fixed in 9.4 but may exist in 9.3
self._test_index_criteria(
- and_(name == 'r4', cast(col[1], String) == '"two"'),
- "r4"
+ and_(name == "r4", cast(col[1], String) == '"two"'), "r4"
)
def test_crit_mixed_path(self):
- col = self.tables.data_table.c['data']
+ col = self.tables.data_table.c["data"]
self._test_index_criteria(
- cast(col[("key3", 1, "six")], String) == '"seven"',
- "r3"
+ cast(col[("key3", 1, "six")], String) == '"seven"', "r3"
)
def test_crit_string_path(self):
- col = self.tables.data_table.c['data']
+ col = self.tables.data_table.c["data"]
self._test_index_criteria(
cast(col[("nested", "elem2", "elem3", "elem4")], String)
== '"elem5"',
- "r5"
+ "r5",
)
def test_crit_against_string_basic(self):
name = self.tables.data_table.c.name
- col = self.tables.data_table.c['data']
+ col = self.tables.data_table.c["data"]
self._test_index_criteria(
- and_(name == 'r6', cast(col["b"], String) == '"some value"'),
- "r6"
+ and_(name == "r6", cast(col["b"], String) == '"some value"'), "r6"
)
def test_crit_against_string_coerce_type(self):
name = self.tables.data_table.c.name
- col = self.tables.data_table.c['data']
+ col = self.tables.data_table.c["data"]
self._test_index_criteria(
- and_(name == 'r6',
- cast(col["b"], String) == type_coerce("some value", JSON)),
+ and_(
+ name == "r6",
+ cast(col["b"], String) == type_coerce("some value", JSON),
+ ),
"r6",
- test_literal=False
+ test_literal=False,
)
def test_crit_against_int_basic(self):
name = self.tables.data_table.c.name
- col = self.tables.data_table.c['data']
+ col = self.tables.data_table.c["data"]
self._test_index_criteria(
- and_(name == 'r6', cast(col["a"], String) == '5'),
- "r6"
+ and_(name == "r6", cast(col["a"], String) == "5"), "r6"
)
def test_crit_against_int_coerce_type(self):
name = self.tables.data_table.c.name
- col = self.tables.data_table.c['data']
+ col = self.tables.data_table.c["data"]
self._test_index_criteria(
- and_(name == 'r6', cast(col["a"], String) == type_coerce(5, JSON)),
+ and_(name == "r6", cast(col["a"], String) == type_coerce(5, JSON)),
"r6",
- test_literal=False
+ test_literal=False,
)
def test_unicode_round_trip(self):
@@ -961,17 +857,17 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest):
{
"name": "r1",
"data": {
- util.u('réveillé'): util.u('réveillé'),
- "data": {"k1": util.u('drôle')}
- }
- }
+ util.u("réveillé"): util.u("réveillé"),
+ "data": {"k1": util.u("drôle")},
+ },
+ },
)
eq_(
conn.scalar(select([self.tables.data_table.c.data])),
{
- util.u('réveillé'): util.u('réveillé'),
- "data": {"k1": util.u('drôle')}
+ util.u("réveillé"): util.u("réveillé"),
+ "data": {"k1": util.u("drôle")},
},
)
@@ -986,7 +882,7 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest):
s = Session(testing.db)
- d1 = Data(name='d1', data=None, nulldata=None)
+ d1 = Data(name="d1", data=None, nulldata=None)
s.add(d1)
s.commit()
@@ -995,24 +891,46 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest):
)
eq_(
s.query(
- cast(self.tables.data_table.c.data, String(convert_unicode="force")),
- cast(self.tables.data_table.c.nulldata, String)
- ).filter(self.tables.data_table.c.name == 'd1').first(),
- ("null", None)
+ cast(
+ self.tables.data_table.c.data,
+ String(convert_unicode="force"),
+ ),
+ cast(self.tables.data_table.c.nulldata, String),
+ )
+ .filter(self.tables.data_table.c.name == "d1")
+ .first(),
+ ("null", None),
)
eq_(
s.query(
- cast(self.tables.data_table.c.data, String(convert_unicode="force")),
- cast(self.tables.data_table.c.nulldata, String)
- ).filter(self.tables.data_table.c.name == 'd2').first(),
- ("null", None)
- )
-
-
-__all__ = ('UnicodeVarcharTest', 'UnicodeTextTest', 'JSONTest',
- 'DateTest', 'DateTimeTest', 'TextTest',
- 'NumericTest', 'IntegerTest',
- 'DateTimeHistoricTest', 'DateTimeCoercedToDateTimeTest',
- 'TimeMicrosecondsTest', 'TimestampMicrosecondsTest', 'TimeTest',
- 'DateTimeMicrosecondsTest',
- 'DateHistoricTest', 'StringTest', 'BooleanTest')
+ cast(
+ self.tables.data_table.c.data,
+ String(convert_unicode="force"),
+ ),
+ cast(self.tables.data_table.c.nulldata, String),
+ )
+ .filter(self.tables.data_table.c.name == "d2")
+ .first(),
+ ("null", None),
+ )
+
+
+__all__ = (
+ "UnicodeVarcharTest",
+ "UnicodeTextTest",
+ "JSONTest",
+ "DateTest",
+ "DateTimeTest",
+ "TextTest",
+ "NumericTest",
+ "IntegerTest",
+ "DateTimeHistoricTest",
+ "DateTimeCoercedToDateTimeTest",
+ "TimeMicrosecondsTest",
+ "TimestampMicrosecondsTest",
+ "TimeTest",
+ "DateTimeMicrosecondsTest",
+ "DateHistoricTest",
+ "StringTest",
+ "BooleanTest",
+)
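The JSONTest round trips hinge on the distinction between SQL NULL and the JSON null value: by default a Python None is persisted as the JSON null literal, while JSON(none_as_null=True) turns None into SQL NULL. A minimal sketch, assuming a backend whose JSON type supports CAST-to-String comparison (illustrative):

    from sqlalchemy import (
        JSON, Column, Integer, MetaData, String, Table, cast, create_engine,
        null, select,
    )

    engine = create_engine("sqlite://")  # assumption: a backend with JSON support
    metadata = MetaData()
    t = Table(
        "data_table",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("data", JSON),                         # None -> JSON 'null'
        Column("nulldata", JSON(none_as_null=True)),  # None -> SQL NULL
    )
    metadata.create_all(engine)
    engine.execute(t.insert(), {"id": 1, "data": None, "nulldata": None})

    # the "data" column holds the JSON null literal...
    print(engine.execute(select([cast(t.c.data, String)])).scalar())  # 'null'
    # ...while "nulldata" is genuinely NULL at the SQL level
    print(
        engine.execute(select([t.c.id]).where(t.c.nulldata.is_(null()))).scalar()
    )  # 1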
diff --git a/lib/sqlalchemy/testing/suite/test_update_delete.py b/lib/sqlalchemy/testing/suite/test_update_delete.py
index e4c61e74a..b232c3a78 100644
--- a/lib/sqlalchemy/testing/suite/test_update_delete.py
+++ b/lib/sqlalchemy/testing/suite/test_update_delete.py
@@ -6,15 +6,17 @@ from ..schema import Table, Column
class SimpleUpdateDeleteTest(fixtures.TablesTest):
- run_deletes = 'each'
+ run_deletes = "each"
__backend__ = True
@classmethod
def define_tables(cls, metadata):
- Table('plain_pk', metadata,
- Column('id', Integer, primary_key=True),
- Column('data', String(50))
- )
+ Table(
+ "plain_pk",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("data", String(50)),
+ )
@classmethod
def insert_data(cls):
@@ -24,40 +26,29 @@ class SimpleUpdateDeleteTest(fixtures.TablesTest):
{"id": 1, "data": "d1"},
{"id": 2, "data": "d2"},
{"id": 3, "data": "d3"},
- ]
+ ],
)
def test_update(self):
t = self.tables.plain_pk
- r = config.db.execute(
- t.update().where(t.c.id == 2),
- data="d2_new"
- )
+ r = config.db.execute(t.update().where(t.c.id == 2), data="d2_new")
assert not r.is_insert
assert not r.returns_rows
eq_(
config.db.execute(t.select().order_by(t.c.id)).fetchall(),
- [
- (1, "d1"),
- (2, "d2_new"),
- (3, "d3")
- ]
+ [(1, "d1"), (2, "d2_new"), (3, "d3")],
)
def test_delete(self):
t = self.tables.plain_pk
- r = config.db.execute(
- t.delete().where(t.c.id == 2)
- )
+ r = config.db.execute(t.delete().where(t.c.id == 2))
assert not r.is_insert
assert not r.returns_rows
eq_(
config.db.execute(t.select().order_by(t.c.id)).fetchall(),
- [
- (1, "d1"),
- (3, "d3")
- ]
+ [(1, "d1"), (3, "d3")],
)
-__all__ = ('SimpleUpdateDeleteTest', )
+
+__all__ = ("SimpleUpdateDeleteTest",)
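SimpleUpdateDeleteTest boils down to the core UPDATE/DELETE round trip; a standalone sketch with the same shape:

    from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine

    engine = create_engine("sqlite://")  # illustrative engine
    metadata = MetaData()
    t = Table(
        "plain_pk",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("data", String(50)),
    )
    metadata.create_all(engine)
    engine.execute(t.insert(), [{"id": i, "data": "d%d" % i} for i in (1, 2, 3)])

    r = engine.execute(t.update().where(t.c.id == 2), data="d2_new")
    assert not r.is_insert and not r.returns_rows  # DML results carry no rows
    engine.execute(t.delete().where(t.c.id == 3))
    print(engine.execute(t.select().order_by(t.c.id)).fetchall())
    # [(1, 'd1'), (2, 'd2_new')]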
diff --git a/lib/sqlalchemy/testing/util.py b/lib/sqlalchemy/testing/util.py
index 409d3bda5..5b015d214 100644
--- a/lib/sqlalchemy/testing/util.py
+++ b/lib/sqlalchemy/testing/util.py
@@ -14,6 +14,7 @@ import sys
import types
if jython:
+
def jython_gc_collect(*args):
"""aggressive gc.collect for tests."""
gc.collect()
@@ -25,9 +26,11 @@ if jython:
# "lazy" gc, for VM's that don't GC on refcount == 0
gc_collect = lazy_gc = jython_gc_collect
elif pypy:
+
def pypy_gc_collect(*args):
gc.collect()
gc.collect()
+
gc_collect = lazy_gc = pypy_gc_collect
else:
# assume CPython - straight gc.collect, lazy_gc() is a pass
@@ -42,11 +45,13 @@ def picklers():
if py2k:
try:
import cPickle
+
picklers.add(cPickle)
except ImportError:
pass
import pickle
+
picklers.add(pickle)
# yes, this thing needs this much testing
@@ -60,9 +65,9 @@ def round_decimal(value, prec):
return round(value, prec)
# can also use shift() here but that is 2.6 only
- return (value * decimal.Decimal("1" + "0" * prec)
- ).to_integral(decimal.ROUND_FLOOR) / \
- pow(10, prec)
+ return (value * decimal.Decimal("1" + "0" * prec)).to_integral(
+ decimal.ROUND_FLOOR
+ ) / pow(10, prec)
class RandomSet(set):
@@ -137,8 +142,9 @@ def function_named(fn, name):
try:
fn.__name__ = name
except TypeError:
- fn = types.FunctionType(fn.__code__, fn.__globals__, name,
- fn.__defaults__, fn.__closure__)
+ fn = types.FunctionType(
+ fn.__code__, fn.__globals__, name, fn.__defaults__, fn.__closure__
+ )
return fn
@@ -190,7 +196,7 @@ def provide_metadata(fn, *args, **kw):
metadata = schema.MetaData(config.db)
self = args[0]
- prev_meta = getattr(self, 'metadata', None)
+ prev_meta = getattr(self, "metadata", None)
self.metadata = metadata
try:
return fn(*args, **kw)
@@ -213,8 +219,8 @@ def force_drop_names(*names):
try:
return fn(*args, **kw)
finally:
- drop_all_tables(
- config.db, inspect(config.db), include_names=names)
+ drop_all_tables(config.db, inspect(config.db), include_names=names)
+
return go
@@ -234,8 +240,13 @@ class adict(dict):
def drop_all_tables(engine, inspector, schema=None, include_names=None):
- from sqlalchemy import Column, Table, Integer, MetaData, \
- ForeignKeyConstraint
+ from sqlalchemy import (
+ Column,
+ Table,
+ Integer,
+ MetaData,
+ ForeignKeyConstraint,
+ )
from sqlalchemy.schema import DropTable, DropConstraint
if include_names is not None:
@@ -243,30 +254,35 @@ def drop_all_tables(engine, inspector, schema=None, include_names=None):
with engine.connect() as conn:
for tname, fkcs in reversed(
- inspector.get_sorted_table_and_fkc_names(schema=schema)):
+ inspector.get_sorted_table_and_fkc_names(schema=schema)
+ ):
if tname:
if include_names is not None and tname not in include_names:
continue
- conn.execute(DropTable(
- Table(tname, MetaData(), schema=schema)
- ))
+ conn.execute(
+ DropTable(Table(tname, MetaData(), schema=schema))
+ )
elif fkcs:
if not engine.dialect.supports_alter:
continue
for tname, fkc in fkcs:
- if include_names is not None and \
- tname not in include_names:
+ if (
+ include_names is not None
+ and tname not in include_names
+ ):
continue
tb = Table(
- tname, MetaData(),
- Column('x', Integer),
- Column('y', Integer),
- schema=schema
+ tname,
+ MetaData(),
+ Column("x", Integer),
+ Column("y", Integer),
+ schema=schema,
+ )
+ conn.execute(
+ DropConstraint(
+ ForeignKeyConstraint([tb.c.x], [tb.c.y], name=fkc)
+ )
)
- conn.execute(DropConstraint(
- ForeignKeyConstraint(
- [tb.c.x], [tb.c.y], name=fkc)
- ))
def teardown_events(event_cls):
@@ -276,5 +292,5 @@ def teardown_events(event_cls):
return fn(*arg, **kw)
finally:
event_cls._clear()
- return decorate
+ return decorate
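The heart of drop_all_tables above is that inspector.get_sorted_table_and_fkc_names() returns tables in dependency order, with a trailing (None, fkcs) entry naming foreign-key constraints involved in cycles, so iterating in reverse never violates a foreign key. A trimmed sketch of that loop:

    from sqlalchemy import MetaData, Table, create_engine, inspect
    from sqlalchemy.schema import DropTable

    engine = create_engine("sqlite://")  # illustrative engine
    inspector = inspect(engine)

    with engine.connect() as conn:
        # reversed(): children first, parents last, so no FK violations
        for tname, fkcs in reversed(inspector.get_sorted_table_and_fkc_names()):
            if tname:
                conn.execute(DropTable(Table(tname, MetaData())))
            # the final (None, fkcs) entry covers cyclic dependencies; the
            # real helper breaks those via DropConstraint before dropping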
diff --git a/lib/sqlalchemy/testing/warnings.py b/lib/sqlalchemy/testing/warnings.py
index 46e7c54db..e0101b14d 100644
--- a/lib/sqlalchemy/testing/warnings.py
+++ b/lib/sqlalchemy/testing/warnings.py
@@ -15,17 +15,20 @@ from . import assertions
def setup_filters():
"""Set global warning behavior for the test suite."""
- warnings.filterwarnings('ignore',
- category=sa_exc.SAPendingDeprecationWarning)
- warnings.filterwarnings('error', category=sa_exc.SADeprecationWarning)
- warnings.filterwarnings('error', category=sa_exc.SAWarning)
+ warnings.filterwarnings(
+ "ignore", category=sa_exc.SAPendingDeprecationWarning
+ )
+ warnings.filterwarnings("error", category=sa_exc.SADeprecationWarning)
+ warnings.filterwarnings("error", category=sa_exc.SAWarning)
# some selected deprecations...
- warnings.filterwarnings('error', category=DeprecationWarning)
+ warnings.filterwarnings("error", category=DeprecationWarning)
warnings.filterwarnings(
- "ignore", category=DeprecationWarning, message=".*StopIteration")
+ "ignore", category=DeprecationWarning, message=".*StopIteration"
+ )
warnings.filterwarnings(
- "ignore", category=DeprecationWarning, message=".*inspect.getargspec")
+ "ignore", category=DeprecationWarning, message=".*inspect.getargspec"
+ )
def assert_warnings(fn, warning_msgs, regex=False):
@@ -36,6 +39,6 @@ def assert_warnings(fn, warning_msgs, regex=False):
"""
with assertions._expect_warnings(
- sa_exc.SAWarning, warning_msgs, regex=regex):
+ sa_exc.SAWarning, warning_msgs, regex=regex
+ ):
return fn()
-
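The filters in setup_filters implement the suite's policy of failing loudly on SQLAlchemy warnings; the same idea in isolation:

    import warnings

    from sqlalchemy import exc as sa_exc

    # pending deprecations stay quiet, but any SADeprecationWarning or
    # SAWarning raised during a test escalates to an error
    warnings.filterwarnings("ignore", category=sa_exc.SAPendingDeprecationWarning)
    warnings.filterwarnings("error", category=sa_exc.SADeprecationWarning)
    warnings.filterwarnings("error", category=sa_exc.SAWarning)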
diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py
index cd0ded7d2..e66582801 100644
--- a/lib/sqlalchemy/types.py
+++ b/lib/sqlalchemy/types.py
@@ -9,15 +9,55 @@
"""
-__all__ = ['TypeEngine', 'TypeDecorator', 'UserDefinedType',
- 'INT', 'CHAR', 'VARCHAR', 'NCHAR', 'NVARCHAR', 'TEXT', 'Text',
- 'FLOAT', 'NUMERIC', 'REAL', 'DECIMAL', 'TIMESTAMP', 'DATETIME',
- 'CLOB', 'BLOB', 'BINARY', 'VARBINARY', 'BOOLEAN', 'BIGINT',
- 'SMALLINT', 'INTEGER', 'DATE', 'TIME', 'String', 'Integer',
- 'SmallInteger', 'BigInteger', 'Numeric', 'Float', 'DateTime',
- 'Date', 'Time', 'LargeBinary', 'Binary', 'Boolean', 'Unicode',
- 'Concatenable', 'UnicodeText', 'PickleType', 'Interval', 'Enum',
- 'Indexable', 'ARRAY', 'JSON']
+__all__ = [
+ "TypeEngine",
+ "TypeDecorator",
+ "UserDefinedType",
+ "INT",
+ "CHAR",
+ "VARCHAR",
+ "NCHAR",
+ "NVARCHAR",
+ "TEXT",
+ "Text",
+ "FLOAT",
+ "NUMERIC",
+ "REAL",
+ "DECIMAL",
+ "TIMESTAMP",
+ "DATETIME",
+ "CLOB",
+ "BLOB",
+ "BINARY",
+ "VARBINARY",
+ "BOOLEAN",
+ "BIGINT",
+ "SMALLINT",
+ "INTEGER",
+ "DATE",
+ "TIME",
+ "String",
+ "Integer",
+ "SmallInteger",
+ "BigInteger",
+ "Numeric",
+ "Float",
+ "DateTime",
+ "Date",
+ "Time",
+ "LargeBinary",
+ "Binary",
+ "Boolean",
+ "Unicode",
+ "Concatenable",
+ "UnicodeText",
+ "PickleType",
+ "Interval",
+ "Enum",
+ "Indexable",
+ "ARRAY",
+ "JSON",
+]
from .sql.type_api import (
adapt_type,
@@ -25,7 +65,7 @@ from .sql.type_api import (
TypeDecorator,
Variant,
to_instance,
- UserDefinedType
+ UserDefinedType,
)
from .sql.sqltypes import (
ARRAY,
@@ -78,4 +118,4 @@ from .sql.sqltypes import (
UnicodeText,
VARBINARY,
VARCHAR,
- )
+)
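Exploding __all__ one name per line is behavior-neutral; the list still controls exactly what a star-import binds. A tiny sketch with a hypothetical module and invented names:

    # mini_types.py -- hypothetical module for illustration only
    __all__ = [
        "Integer",
        "String",
    ]

    class Integer(object): pass
    class String(object): pass
    class Internal(object): pass  # defined but unlisted: star-imports skip it

With this layout, "from mini_types import *" binds Integer and String only, while "import mini_types" still exposes Internal as an ordinary attribute.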
diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py
index d8c28d6af..103225e2a 100644
--- a/lib/sqlalchemy/util/__init__.py
+++ b/lib/sqlalchemy/util/__init__.py
@@ -5,42 +5,146 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-from .compat import callable, cmp, reduce, \
- threading, py3k, py33, py36, py2k, jython, pypy, cpython, win32, \
- pickle, dottedgetter, parse_qsl, namedtuple, next, reraise, \
- raise_from_cause, text_type, safe_kwarg, string_types, int_types, \
- binary_type, nested, \
- quote_plus, with_metaclass, print_, itertools_filterfalse, u, ue, b,\
- unquote_plus, unquote, b64decode, b64encode, byte_buffer, itertools_filter,\
- iterbytes, StringIO, inspect_getargspec, zip_longest
+from .compat import (
+ callable,
+ cmp,
+ reduce,
+ threading,
+ py3k,
+ py33,
+ py36,
+ py2k,
+ jython,
+ pypy,
+ cpython,
+ win32,
+ pickle,
+ dottedgetter,
+ parse_qsl,
+ namedtuple,
+ next,
+ reraise,
+ raise_from_cause,
+ text_type,
+ safe_kwarg,
+ string_types,
+ int_types,
+ binary_type,
+ nested,
+ quote_plus,
+ with_metaclass,
+ print_,
+ itertools_filterfalse,
+ u,
+ ue,
+ b,
+ unquote_plus,
+ unquote,
+ b64decode,
+ b64encode,
+ byte_buffer,
+ itertools_filter,
+ iterbytes,
+ StringIO,
+ inspect_getargspec,
+ zip_longest,
+)
-from ._collections import KeyedTuple, ImmutableContainer, immutabledict, \
- Properties, OrderedProperties, ImmutableProperties, OrderedDict, \
- OrderedSet, IdentitySet, OrderedIdentitySet, column_set, \
- column_dict, ordered_column_set, populate_column_dict, unique_list, \
- UniqueAppender, PopulateDict, EMPTY_SET, to_list, to_set, \
- to_column_set, update_copy, flatten_iterator, has_intersection, \
- LRUCache, ScopedRegistry, ThreadLocalRegistry, WeakSequence, \
- coerce_generator_arg, lightweight_named_tuple, collections_abc, \
- has_dupes
+from ._collections import (
+ KeyedTuple,
+ ImmutableContainer,
+ immutabledict,
+ Properties,
+ OrderedProperties,
+ ImmutableProperties,
+ OrderedDict,
+ OrderedSet,
+ IdentitySet,
+ OrderedIdentitySet,
+ column_set,
+ column_dict,
+ ordered_column_set,
+ populate_column_dict,
+ unique_list,
+ UniqueAppender,
+ PopulateDict,
+ EMPTY_SET,
+ to_list,
+ to_set,
+ to_column_set,
+ update_copy,
+ flatten_iterator,
+ has_intersection,
+ LRUCache,
+ ScopedRegistry,
+ ThreadLocalRegistry,
+ WeakSequence,
+ coerce_generator_arg,
+ lightweight_named_tuple,
+ collections_abc,
+ has_dupes,
+)
-from .langhelpers import iterate_attributes, class_hierarchy, \
- portable_instancemethod, unbound_method_to_callable, \
- getargspec_init, format_argspec_init, format_argspec_plus, \
- get_func_kwargs, get_cls_kwargs, decorator, as_interface, \
- memoized_property, memoized_instancemethod, md5_hex, \
- group_expirable_memoized_property, dependencies, decode_slice, \
- monkeypatch_proxied_specials, asbool, bool_or_str, coerce_kw_type,\
- duck_type_collection, assert_arg_type, symbol, dictlike_iteritems,\
- classproperty, set_creation_order, warn_exception, warn, NoneType,\
- constructor_copy, methods_equivalent, chop_traceback, asint,\
- generic_repr, counter, PluginLoader, hybridproperty, hybridmethod, \
- safe_reraise, quoted_token_parser,\
- get_callable_argspec, only_once, attrsetter, ellipses_string, \
- warn_limited, map_bits, MemoizedSlots, EnsureKWArgType, wrap_callable
+from .langhelpers import (
+ iterate_attributes,
+ class_hierarchy,
+ portable_instancemethod,
+ unbound_method_to_callable,
+ getargspec_init,
+ format_argspec_init,
+ format_argspec_plus,
+ get_func_kwargs,
+ get_cls_kwargs,
+ decorator,
+ as_interface,
+ memoized_property,
+ memoized_instancemethod,
+ md5_hex,
+ group_expirable_memoized_property,
+ dependencies,
+ decode_slice,
+ monkeypatch_proxied_specials,
+ asbool,
+ bool_or_str,
+ coerce_kw_type,
+ duck_type_collection,
+ assert_arg_type,
+ symbol,
+ dictlike_iteritems,
+ classproperty,
+ set_creation_order,
+ warn_exception,
+ warn,
+ NoneType,
+ constructor_copy,
+ methods_equivalent,
+ chop_traceback,
+ asint,
+ generic_repr,
+ counter,
+ PluginLoader,
+ hybridproperty,
+ hybridmethod,
+ safe_reraise,
+ quoted_token_parser,
+ get_callable_argspec,
+ only_once,
+ attrsetter,
+ ellipses_string,
+ warn_limited,
+ map_bits,
+ MemoizedSlots,
+ EnsureKWArgType,
+ wrap_callable,
+)
-from .deprecations import warn_deprecated, warn_pending_deprecation, \
- deprecated, pending_deprecation, inject_docstring_text
+from .deprecations import (
+ warn_deprecated,
+ warn_pending_deprecation,
+ deprecated,
+ pending_deprecation,
+ inject_docstring_text,
+)
# things that used to be not always available,
# but are now as of current support Python versions
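Every import block in this file changes shape the same way: backslash continuations become a parenthesized list, one name per line, trailing comma. The two spellings bind identical names; a sketch against stdlib os.path:

    # Old style: backslash continuation, brittle about trailing whitespace.
    from os.path import join, split, \
        basename

    # New style: parentheses allow free line breaks and a trailing comma.
    from os.path import (
        join,
        split,
        basename,
    )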
diff --git a/lib/sqlalchemy/util/_collections.py b/lib/sqlalchemy/util/_collections.py
index 43440134a..67be0e6bf 100644
--- a/lib/sqlalchemy/util/_collections.py
+++ b/lib/sqlalchemy/util/_collections.py
@@ -10,8 +10,13 @@
from __future__ import absolute_import
import weakref
import operator
-from .compat import threading, itertools_filterfalse, string_types, \
- binary_types, collections_abc
+from .compat import (
+ threading,
+ itertools_filterfalse,
+ string_types,
+ binary_types,
+ collections_abc,
+)
from . import py2k
import types
@@ -77,7 +82,7 @@ class KeyedTuple(AbstractKeyedTuple):
t.__dict__.update(zip(labels, vals))
else:
labels = []
- t.__dict__['_labels'] = labels
+ t.__dict__["_labels"] = labels
return t
@property
@@ -139,8 +144,7 @@ class ImmutableContainer(object):
class immutabledict(ImmutableContainer, dict):
- clear = pop = popitem = setdefault = \
- update = ImmutableContainer._immutable
+ clear = pop = popitem = setdefault = update = ImmutableContainer._immutable
def __new__(cls, *args):
new = dict.__new__(cls)
@@ -151,7 +155,7 @@ class immutabledict(ImmutableContainer, dict):
pass
def __reduce__(self):
- return immutabledict, (dict(self), )
+ return immutabledict, (dict(self),)
def union(self, d):
if not d:
@@ -173,10 +177,10 @@ class immutabledict(ImmutableContainer, dict):
class Properties(object):
"""Provide a __getattr__/__setattr__ interface over a dict."""
- __slots__ = '_data',
+ __slots__ = ("_data",)
def __init__(self, data):
- object.__setattr__(self, '_data', data)
+ object.__setattr__(self, "_data", data)
def __len__(self):
return len(self._data)
@@ -185,7 +189,9 @@ class Properties(object):
return iter(list(self._data.values()))
def __dir__(self):
- return dir(super(Properties, self)) + [str(k) for k in self._data.keys()]
+ return dir(super(Properties, self)) + [
+ str(k) for k in self._data.keys()
+ ]
def __add__(self, other):
return list(self) + list(other)
@@ -203,10 +209,10 @@ class Properties(object):
self._data[key] = obj
def __getstate__(self):
- return {'_data': self._data}
+ return {"_data": self._data}
def __setstate__(self, state):
- object.__setattr__(self, '_data', state['_data'])
+ object.__setattr__(self, "_data", state["_data"])
def __getattr__(self, key):
try:
@@ -266,7 +272,7 @@ class ImmutableProperties(ImmutableContainer, Properties):
class OrderedDict(dict):
"""A dict that returns keys/values/items in the order they were added."""
- __slots__ = '_list',
+ __slots__ = ("_list",)
def __reduce__(self):
return OrderedDict, (self.items(),)
@@ -294,7 +300,7 @@ class OrderedDict(dict):
def update(self, ____sequence=None, **kwargs):
if ____sequence is not None:
- if hasattr(____sequence, 'keys'):
+ if hasattr(____sequence, "keys"):
for key in ____sequence.keys():
self.__setitem__(key, ____sequence[key])
else:
@@ -323,6 +329,7 @@ class OrderedDict(dict):
return [(key, self[key]) for key in self._list]
if py2k:
+
def itervalues(self):
return iter(self.values())
@@ -402,7 +409,7 @@ class OrderedSet(set):
return self.union(other)
def __repr__(self):
- return '%s(%r)' % (self.__class__.__name__, self._list)
+ return "%s(%r)" % (self.__class__.__name__, self._list)
__str__ = __repr__
@@ -502,13 +509,13 @@ class IdentitySet(object):
pair = self._members.popitem()
return pair[1]
except KeyError:
- raise KeyError('pop from an empty set')
+ raise KeyError("pop from an empty set")
def clear(self):
self._members.clear()
def __cmp__(self, other):
- raise TypeError('cannot compare sets using cmp()')
+ raise TypeError("cannot compare sets using cmp()")
def __eq__(self, other):
if isinstance(other, IdentitySet):
@@ -527,8 +534,9 @@ class IdentitySet(object):
if len(self) > len(other):
return False
- for m in itertools_filterfalse(other._members.__contains__,
- iter(self._members.keys())):
+ for m in itertools_filterfalse(
+ other._members.__contains__, iter(self._members.keys())
+ ):
return False
return True
@@ -548,8 +556,9 @@ class IdentitySet(object):
if len(self) < len(other):
return False
- for m in itertools_filterfalse(self._members.__contains__,
- iter(other._members.keys())):
+ for m in itertools_filterfalse(
+ self._members.__contains__, iter(other._members.keys())
+ ):
return False
return True
@@ -635,7 +644,8 @@ class IdentitySet(object):
members = self._member_id_tuples()
other = _iter_id(iterable)
result._members.update(
- self._working_set(members).symmetric_difference(other))
+ self._working_set(members).symmetric_difference(other)
+ )
return result
def _member_id_tuples(self):
@@ -667,10 +677,10 @@ class IdentitySet(object):
return iter(self._members.values())
def __hash__(self):
- raise TypeError('set objects are unhashable')
+ raise TypeError("set objects are unhashable")
def __repr__(self):
- return '%s(%r)' % (type(self).__name__, list(self._members.values()))
+ return "%s(%r)" % (type(self).__name__, list(self._members.values()))
class WeakSequence(object):
@@ -689,8 +699,9 @@ class WeakSequence(object):
return len(self._storage)
def __iter__(self):
- return (obj for obj in
- (ref() for ref in self._storage) if obj is not None)
+ return (
+ obj for obj in (ref() for ref in self._storage) if obj is not None
+ )
def __getitem__(self, index):
try:
@@ -732,6 +743,7 @@ class PopulateDict(dict):
self[key] = val = self.creator(key)
return val
+
# Define collections that are capable of storing
# ColumnElement objects as hashable keys/elements.
# At this point, these are mostly historical, things
@@ -745,20 +757,21 @@ populate_column_dict = PopulateDict
_getters = PopulateDict(operator.itemgetter)
_property_getters = PopulateDict(
- lambda idx: property(operator.itemgetter(idx)))
+ lambda idx: property(operator.itemgetter(idx))
+)
def unique_list(seq, hashfunc=None):
seen = set()
seen_add = seen.add
if not hashfunc:
- return [x for x in seq
- if x not in seen
- and not seen_add(x)]
+ return [x for x in seq if x not in seen and not seen_add(x)]
else:
- return [x for x in seq
- if hashfunc(x) not in seen
- and not seen_add(hashfunc(x))]
+ return [
+ x
+ for x in seq
+ if hashfunc(x) not in seen and not seen_add(hashfunc(x))
+ ]
class UniqueAppender(object):
@@ -773,9 +786,9 @@ class UniqueAppender(object):
self._unique = {}
if via:
self._data_appender = getattr(data, via)
- elif hasattr(data, 'append'):
+ elif hasattr(data, "append"):
self._data_appender = data.append
- elif hasattr(data, 'add'):
+ elif hasattr(data, "add"):
self._data_appender = data.add
def append(self, item):
@@ -798,8 +811,9 @@ def coerce_generator_arg(arg):
def to_list(x, default=None):
if x is None:
return default
- if not isinstance(x, collections_abc.Iterable) or \
- isinstance(x, string_types + binary_types):
+ if not isinstance(x, collections_abc.Iterable) or isinstance(
+ x, string_types + binary_types
+ ):
return [x]
elif isinstance(x, list):
return x
@@ -815,9 +829,7 @@ def has_intersection(set_, iterable):
"""
# TODO: optimize, write in C, etc.
- return bool(
- set_.intersection([i for i in iterable if i.__hash__])
- )
+ return bool(set_.intersection([i for i in iterable if i.__hash__]))
def to_set(x):
@@ -854,7 +866,7 @@ def flatten_iterator(x):
"""
for elem in x:
- if not isinstance(elem, str) and hasattr(elem, '__iter__'):
+ if not isinstance(elem, str) and hasattr(elem, "__iter__"):
for y in flatten_iterator(elem):
yield y
else:
@@ -871,9 +883,9 @@ class LRUCache(dict):
"""
- __slots__ = 'capacity', 'threshold', 'size_alert', '_counter', '_mutex'
+ __slots__ = "capacity", "threshold", "size_alert", "_counter", "_mutex"
- def __init__(self, capacity=100, threshold=.5, size_alert=None):
+ def __init__(self, capacity=100, threshold=0.5, size_alert=None):
self.capacity = capacity
self.threshold = threshold
self.size_alert = size_alert
@@ -929,10 +941,10 @@ class LRUCache(dict):
if size_alert:
size_alert = False
self.size_alert(self)
- by_counter = sorted(dict.values(self),
- key=operator.itemgetter(2),
- reverse=True)
- for item in by_counter[self.capacity:]:
+ by_counter = sorted(
+ dict.values(self), key=operator.itemgetter(2), reverse=True
+ )
+ for item in by_counter[self.capacity :]:
try:
del self[item[0]]
except KeyError:
@@ -946,17 +958,22 @@ _lw_tuples = LRUCache(100)
def lightweight_named_tuple(name, fields):
- hash_ = (name, ) + tuple(fields)
+ hash_ = (name,) + tuple(fields)
tp_cls = _lw_tuples.get(hash_)
if tp_cls:
return tp_cls
tp_cls = type(
- name, (_LW,),
- dict([
- (field, _property_getters[idx])
- for idx, field in enumerate(fields) if field is not None
- ] + [('__slots__', ())])
+ name,
+ (_LW,),
+ dict(
+ [
+ (field, _property_getters[idx])
+ for idx, field in enumerate(fields)
+ if field is not None
+ ]
+ + [("__slots__", ())]
+ ),
)
tp_cls._real_fields = fields
@@ -1077,6 +1094,7 @@ def has_dupes(sequence, target):
return True
return False
+
# .index version. the two __contains__ calls as well
# as .index() and isinstance() slow this down.
# def has_dupes(sequence, target):
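The unique_list() rewrite above is pure reflow, and it preserves a classic order-preserving dedup idiom: set.add() returns None, so "not seen_add(x)" is always true and exists purely for its side effect of recording the key. A condensed sketch with the two branches folded into one path (a simplification for brevity, not the library's exact code):

    def unique_list(seq, hashfunc=None):
        seen = set()
        seen_add = seen.add
        key = hashfunc if hashfunc else (lambda x: x)
        # "not seen_add(...)" records the key and never filters anything out.
        return [x for x in seq if key(x) not in seen and not seen_add(key(x))]

    print(unique_list([3, 1, 3, 2, 1]))        # [3, 1, 2]
    print(unique_list(["a", "A"], str.lower))  # ['a']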
diff --git a/lib/sqlalchemy/util/compat.py b/lib/sqlalchemy/util/compat.py
index b01471edf..553624b49 100644
--- a/lib/sqlalchemy/util/compat.py
+++ b/lib/sqlalchemy/util/compat.py
@@ -20,9 +20,9 @@ py32 = sys.version_info >= (3, 2)
py3k = sys.version_info >= (3, 0)
py2k = sys.version_info < (3, 0)
py265 = sys.version_info >= (2, 6, 5)
-jython = sys.platform.startswith('java')
-pypy = hasattr(sys, 'pypy_version_info')
-win32 = sys.platform.startswith('win')
+jython = sys.platform.startswith("java")
+pypy = hasattr(sys, "pypy_version_info")
+win32 = sys.platform.startswith("win")
cpython = not pypy and not jython # TODO: something better for this ?
contextmanager = contextlib.contextmanager
@@ -30,8 +30,9 @@ dottedgetter = operator.attrgetter
namedtuple = collections.namedtuple
next = next
-ArgSpec = collections.namedtuple("ArgSpec",
- ["args", "varargs", "keywords", "defaults"])
+ArgSpec = collections.namedtuple(
+ "ArgSpec", ["args", "varargs", "keywords", "defaults"]
+)
try:
import threading
@@ -58,40 +59,43 @@ if py3k:
from io import BytesIO as byte_buffer
from io import StringIO
from itertools import zip_longest
- from urllib.parse import (quote_plus, unquote_plus, parse_qsl, quote, unquote)
-
- string_types = str,
- binary_types = bytes,
+ from urllib.parse import (
+ quote_plus,
+ unquote_plus,
+ parse_qsl,
+ quote,
+ unquote,
+ )
+
+ string_types = (str,)
+ binary_types = (bytes,)
binary_type = bytes
text_type = str
- int_types = int,
+ int_types = (int,)
iterbytes = iter
itertools_filterfalse = itertools.filterfalse
itertools_filter = filter
itertools_imap = map
- exec_ = getattr(builtins, 'exec')
- import_ = getattr(builtins, '__import__')
+ exec_ = getattr(builtins, "exec")
+ import_ = getattr(builtins, "__import__")
print_ = getattr(builtins, "print")
def b(s):
return s.encode("latin-1")
def b64decode(x):
- return base64.b64decode(x.encode('ascii'))
-
+ return base64.b64decode(x.encode("ascii"))
def b64encode(x):
- return base64.b64encode(x).decode('ascii')
+ return base64.b64encode(x).decode("ascii")
def cmp(a, b):
return (a > b) - (a < b)
def inspect_getargspec(func):
- return ArgSpec(
- *inspect_getfullargspec(func)[0:4]
- )
+ return ArgSpec(*inspect_getfullargspec(func)[0:4])
def reraise(tp, value, tb=None, cause=None):
if cause is not None:
@@ -110,8 +114,11 @@ if py3k:
if py32:
callable = callable
else:
+
def callable(fn):
- return hasattr(fn, '__call__')
+ return hasattr(fn, "__call__")
+
+
else:
import base64
import ConfigParser as configparser
@@ -129,8 +136,8 @@ else:
except ImportError:
import pickle
- string_types = basestring,
- binary_types = bytes,
+ string_types = (basestring,)
+ binary_types = (bytes,)
binary_type = str
text_type = unicode
int_types = int, long
@@ -153,9 +160,9 @@ else:
def exec_(func_text, globals_, lcl=None):
if lcl is None:
- exec('exec func_text in globals_')
+ exec("exec func_text in globals_")
else:
- exec('exec func_text in globals_, lcl')
+ exec("exec func_text in globals_, lcl")
def iterbytes(buf):
return (ord(byte) for byte in buf)
@@ -186,24 +193,32 @@ else:
# not as nice as that of Py3K, but at least preserves
# the code line where the issue occurred
- exec("def reraise(tp, value, tb=None, cause=None):\n"
- " if cause is not None:\n"
- " assert cause is not value, 'Same cause emitted'\n"
- " raise tp, value, tb\n")
+ exec(
+ "def reraise(tp, value, tb=None, cause=None):\n"
+ " if cause is not None:\n"
+ " assert cause is not value, 'Same cause emitted'\n"
+ " raise tp, value, tb\n"
+ )
if py35:
from inspect import formatannotation
def inspect_formatargspec(
- args, varargs=None, varkw=None, defaults=None,
- kwonlyargs=(), kwonlydefaults={}, annotations={},
- formatarg=str,
- formatvarargs=lambda name: '*' + name,
- formatvarkw=lambda name: '**' + name,
- formatvalue=lambda value: '=' + repr(value),
- formatreturns=lambda text: ' -> ' + text,
- formatannotation=formatannotation):
+ args,
+ varargs=None,
+ varkw=None,
+ defaults=None,
+ kwonlyargs=(),
+ kwonlydefaults={},
+ annotations={},
+ formatarg=str,
+ formatvarargs=lambda name: "*" + name,
+ formatvarkw=lambda name: "**" + name,
+ formatvalue=lambda value: "=" + repr(value),
+ formatreturns=lambda text: " -> " + text,
+ formatannotation=formatannotation,
+ ):
"""Copy formatargspec from python 3.7 standard library.
Python 3 has deprecated formatargspec and requested that Signature
@@ -221,7 +236,7 @@ if py35:
def formatargandannotation(arg):
result = formatarg(arg)
if arg in annotations:
- result += ': ' + formatannotation(annotations[arg])
+ result += ": " + formatannotation(annotations[arg])
return result
specs = []
@@ -237,7 +252,7 @@ if py35:
specs.append(formatvarargs(formatargandannotation(varargs)))
else:
if kwonlyargs:
- specs.append('*')
+ specs.append("*")
if kwonlyargs:
for kwonlyarg in kwonlyargs:
@@ -249,10 +264,12 @@ if py35:
if varkw is not None:
specs.append(formatvarkw(formatargandannotation(varkw)))
- result = '(' + ', '.join(specs) + ')'
- if 'return' in annotations:
- result += formatreturns(formatannotation(annotations['return']))
+ result = "(" + ", ".join(specs) + ")"
+ if "return" in annotations:
+ result += formatreturns(formatannotation(annotations["return"]))
return result
+
+
else:
from inspect import formatargspec as inspect_formatargspec
@@ -330,4 +347,5 @@ def with_metaclass(meta, *bases):
if this_bases is None:
return type.__new__(cls, name, (), d)
return meta(name, bases, d)
- return metaclass('temporary_class', None, {})
+
+ return metaclass("temporary_class", None, {})
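The py3k cmp() shim in this file leans on bool arithmetic: True and False are the integers 1 and 0, so (a > b) - (a < b) yields 1, 0, or -1 exactly like Python 2's built-in. For illustration:

    def cmp(a, b):
        # Each comparison is a bool, i.e. 0 or 1 under subtraction.
        return (a > b) - (a < b)

    print(cmp(5, 3), cmp(3, 3), cmp(1, 3))  # 1 0 -1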
diff --git a/lib/sqlalchemy/util/deprecations.py b/lib/sqlalchemy/util/deprecations.py
index 9000cc795..e6612f075 100644
--- a/lib/sqlalchemy/util/deprecations.py
+++ b/lib/sqlalchemy/util/deprecations.py
@@ -40,8 +40,7 @@ def deprecated(version, message=None, add_deprecation_to_docstring=True):
"""
if add_deprecation_to_docstring:
- header = ".. deprecated:: %s %s" % \
- (version, (message or ''))
+ header = ".. deprecated:: %s %s" % (version, (message or ""))
else:
header = None
@@ -50,13 +49,18 @@ def deprecated(version, message=None, add_deprecation_to_docstring=True):
def decorate(fn):
return _decorate_with_warning(
- fn, exc.SADeprecationWarning,
- message % dict(func=fn.__name__), header)
+ fn,
+ exc.SADeprecationWarning,
+ message % dict(func=fn.__name__),
+ header,
+ )
+
return decorate
-def pending_deprecation(version, message=None,
- add_deprecation_to_docstring=True):
+def pending_deprecation(
+ version, message=None, add_deprecation_to_docstring=True
+):
"""Decorates a function and issues a pending deprecation warning on use.
:param version:
@@ -74,8 +78,7 @@ def pending_deprecation(version, message=None,
"""
if add_deprecation_to_docstring:
- header = ".. deprecated:: %s (pending) %s" % \
- (version, (message or ''))
+ header = ".. deprecated:: %s (pending) %s" % (version, (message or ""))
else:
header = None
@@ -84,8 +87,12 @@ def pending_deprecation(version, message=None,
def decorate(fn):
return _decorate_with_warning(
- fn, exc.SAPendingDeprecationWarning,
- message % dict(func=fn.__name__), header)
+ fn,
+ exc.SAPendingDeprecationWarning,
+ message % dict(func=fn.__name__),
+ header,
+ )
+
return decorate
@@ -95,7 +102,8 @@ def _sanitize_restructured_text(text):
if type_ in ("func", "meth"):
name += "()"
return name
- return re.sub(r'\:(\w+)\:`~?\.?(.+?)`', repl, text)
+
+ return re.sub(r"\:(\w+)\:`~?\.?(.+?)`", repl, text)
def _decorate_with_warning(func, wtype, message, docstring_header=None):
@@ -108,7 +116,7 @@ def _decorate_with_warning(func, wtype, message, docstring_header=None):
warnings.warn(message, wtype, stacklevel=3)
return fn(*args, **kwargs)
- doc = func.__doc__ is not None and func.__doc__ or ''
+ doc = func.__doc__ is not None and func.__doc__ or ""
if docstring_header is not None:
docstring_header %= dict(func=func.__name__)
@@ -118,6 +126,7 @@ def _decorate_with_warning(func, wtype, message, docstring_header=None):
decorated.__doc__ = doc
return decorated
+
import textwrap
@@ -135,7 +144,7 @@ def _dedent_docstring(text):
def inject_docstring_text(doctext, injecttext, pos):
doctext = _dedent_docstring(doctext or "")
- lines = doctext.split('\n')
+ lines = doctext.split("\n")
injectlines = textwrap.dedent(injecttext).split("\n")
if injectlines[0]:
injectlines.insert(0, "")
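Both deprecated() and pending_deprecation() above delegate to _decorate_with_warning(), which wraps the target so every call emits a warning with %(func)s interpolated. A standalone sketch of that shape, substituting the stdlib DeprecationWarning for SQLAlchemy's exception classes and omitting the docstring-injection step:

    import warnings
    from functools import update_wrapper

    def deprecated(message):
        def decorate(fn):
            text = message % dict(func=fn.__name__)

            def warned(*args, **kwargs):
                warnings.warn(text, DeprecationWarning, stacklevel=2)
                return fn(*args, **kwargs)

            return update_wrapper(warned, fn)

        return decorate

    @deprecated("%(func)s() is deprecated; use the new API instead")
    def old_api():  # hypothetical target
        return 42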
diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py
index 7e387f4f2..6a286998b 100644
--- a/lib/sqlalchemy/util/langhelpers.py
+++ b/lib/sqlalchemy/util/langhelpers.py
@@ -25,7 +25,7 @@ from . import _collections
def md5_hex(x):
if compat.py3k:
- x = x.encode('utf-8')
+ x = x.encode("utf-8")
m = hashlib.md5()
m.update(x)
return m.hexdigest()
@@ -49,7 +49,7 @@ class safe_reraise(object):
"""
- __slots__ = ('warn_only', '_exc_info')
+ __slots__ = ("warn_only", "_exc_info")
def __init__(self, warn_only=False):
self.warn_only = warn_only
@@ -61,7 +61,7 @@ class safe_reraise(object):
# see #2703 for notes
if type_ is None:
exc_type, exc_value, exc_tb = self._exc_info
- self._exc_info = None # remove potential circular references
+ self._exc_info = None # remove potential circular references
if not self.warn_only:
compat.reraise(exc_type, exc_value, exc_tb)
else:
@@ -71,8 +71,9 @@ class safe_reraise(object):
warn(
"An exception has occurred during handling of a "
"previous exception. The previous exception "
- "is:\n %s %s\n" % (self._exc_info[0], self._exc_info[1]))
- self._exc_info = None # remove potential circular references
+ "is:\n %s %s\n" % (self._exc_info[0], self._exc_info[1])
+ )
+ self._exc_info = None # remove potential circular references
compat.reraise(type_, value, traceback)
@@ -84,7 +85,7 @@ def decode_slice(slc):
"""
ret = []
for x in slc.start, slc.stop, slc.step:
- if hasattr(x, '__index__'):
+ if hasattr(x, "__index__"):
x = x.__index__()
ret.append(x)
return tuple(ret)
@@ -93,9 +94,10 @@ def decode_slice(slc):
def _unique_symbols(used, *bases):
used = set(used)
for base in bases:
- pool = itertools.chain((base,),
- compat.itertools_imap(lambda i: base + str(i),
- range(1000)))
+ pool = itertools.chain(
+ (base,),
+ compat.itertools_imap(lambda i: base + str(i), range(1000)),
+ )
for sym in pool:
if sym not in used:
used.add(sym)
@@ -122,21 +124,25 @@ def decorator(target):
raise Exception("not a decoratable function")
spec = compat.inspect_getfullargspec(fn)
names = tuple(spec[0]) + spec[1:3] + (fn.__name__,)
- targ_name, fn_name = _unique_symbols(names, 'target', 'fn')
+ targ_name, fn_name = _unique_symbols(names, "target", "fn")
metadata = dict(target=targ_name, fn=fn_name)
metadata.update(format_argspec_plus(spec, grouped=False))
- metadata['name'] = fn.__name__
- code = """\
+ metadata["name"] = fn.__name__
+ code = (
+ """\
def %(name)s(%(args)s):
return %(target)s(%(fn)s, %(apply_kw)s)
-""" % metadata
- decorated = _exec_code_in_env(code,
- {targ_name: target, fn_name: fn},
- fn.__name__)
- decorated.__defaults__ = getattr(fn, 'im_func', fn).__defaults__
+"""
+ % metadata
+ )
+ decorated = _exec_code_in_env(
+ code, {targ_name: target, fn_name: fn}, fn.__name__
+ )
+ decorated.__defaults__ = getattr(fn, "im_func", fn).__defaults__
decorated.__wrapped__ = fn
return update_wrapper(decorated, fn)
+
return update_wrapper(decorate, target)
@@ -155,31 +161,38 @@ def public_factory(target, location):
if isinstance(target, type):
fn = target.__init__
callable_ = target
- doc = "Construct a new :class:`.%s` object. \n\n"\
- "This constructor is mirrored as a public API function; "\
- "see :func:`~%s` "\
- "for a full usage and argument description." % (
- target.__name__, location, )
+ doc = (
+ "Construct a new :class:`.%s` object. \n\n"
+ "This constructor is mirrored as a public API function; "
+ "see :func:`~%s` "
+ "for a full usage and argument description."
+ % (target.__name__, location)
+ )
else:
fn = callable_ = target
- doc = "This function is mirrored; see :func:`~%s` "\
+ doc = (
+ "This function is mirrored; see :func:`~%s` "
"for a description of arguments." % location
+ )
location_name = location.split(".")[-1]
spec = compat.inspect_getfullargspec(fn)
del spec[0][0]
metadata = format_argspec_plus(spec, grouped=False)
- metadata['name'] = location_name
- code = """\
+ metadata["name"] = location_name
+ code = (
+ """\
def %(name)s(%(args)s):
return cls(%(apply_kw)s)
-""" % metadata
- env = {'cls': callable_, 'symbol': symbol}
+"""
+ % metadata
+ )
+ env = {"cls": callable_, "symbol": symbol}
exec(code, env)
decorated = env[location_name]
decorated.__doc__ = fn.__doc__
decorated.__module__ = "sqlalchemy" + location.rsplit(".", 1)[0]
- if compat.py2k or hasattr(fn, '__func__'):
+ if compat.py2k or hasattr(fn, "__func__"):
fn.__func__.__doc__ = doc
else:
fn.__doc__ = doc
@@ -187,7 +200,6 @@ def %(name)s(%(args)s):
class PluginLoader(object):
-
def __init__(self, group, auto_fn=None):
self.group = group
self.impls = {}
@@ -211,14 +223,13 @@ class PluginLoader(object):
except ImportError:
pass
else:
- for impl in pkg_resources.iter_entry_points(
- self.group, name):
+ for impl in pkg_resources.iter_entry_points(self.group, name):
self.impls[name] = impl.load
return impl.load()
raise exc.NoSuchModuleError(
- "Can't load plugin: %s:%s" %
- (self.group, name))
+ "Can't load plugin: %s:%s" % (self.group, name)
+ )
def register(self, name, modulepath, objname):
def load():
@@ -226,6 +237,7 @@ class PluginLoader(object):
for token in modulepath.split(".")[1:]:
mod = getattr(mod, token)
return getattr(mod, objname)
+
self.impls[name] = load
@@ -245,10 +257,13 @@ def get_cls_kwargs(cls, _set=None):
if toplevel:
_set = set()
- ctr = cls.__dict__.get('__init__', False)
+ ctr = cls.__dict__.get("__init__", False)
- has_init = ctr and isinstance(ctr, types.FunctionType) and \
- isinstance(ctr.__code__, types.CodeType)
+ has_init = (
+ ctr
+ and isinstance(ctr, types.FunctionType)
+ and isinstance(ctr.__code__, types.CodeType)
+ )
if has_init:
names, has_kw = inspect_func_args(ctr)
@@ -262,7 +277,7 @@ def get_cls_kwargs(cls, _set=None):
if get_cls_kwargs(c, _set) is None:
break
- _set.discard('self')
+ _set.discard("self")
return _set
@@ -278,7 +293,9 @@ try:
has_kw = bool(co.co_flags & CO_VARKEYWORDS)
return args, has_kw
+
except ImportError:
+
def inspect_func_args(fn):
names, _, has_kw, _ = compat.inspect_getargspec(fn)
return names, bool(has_kw)
@@ -309,23 +326,26 @@ def get_callable_argspec(fn, no_self=False, _is_init=False):
elif inspect.isfunction(fn):
if _is_init and no_self:
spec = compat.inspect_getargspec(fn)
- return compat.ArgSpec(spec.args[1:], spec.varargs,
- spec.keywords, spec.defaults)
+ return compat.ArgSpec(
+ spec.args[1:], spec.varargs, spec.keywords, spec.defaults
+ )
else:
return compat.inspect_getargspec(fn)
elif inspect.ismethod(fn):
if no_self and (_is_init or fn.__self__):
spec = compat.inspect_getargspec(fn.__func__)
- return compat.ArgSpec(spec.args[1:], spec.varargs,
- spec.keywords, spec.defaults)
+ return compat.ArgSpec(
+ spec.args[1:], spec.varargs, spec.keywords, spec.defaults
+ )
else:
return compat.inspect_getargspec(fn.__func__)
elif inspect.isclass(fn):
return get_callable_argspec(
- fn.__init__, no_self=no_self, _is_init=True)
- elif hasattr(fn, '__func__'):
+ fn.__init__, no_self=no_self, _is_init=True
+ )
+ elif hasattr(fn, "__func__"):
return compat.inspect_getargspec(fn.__func__)
- elif hasattr(fn, '__call__'):
+ elif hasattr(fn, "__call__"):
if inspect.ismethod(fn.__call__):
return get_callable_argspec(fn.__call__, no_self=no_self)
else:
@@ -375,13 +395,14 @@ def format_argspec_plus(fn, grouped=True):
if spec[0]:
self_arg = spec[0][0]
elif spec[1]:
- self_arg = '%s[0]' % spec[1]
+ self_arg = "%s[0]" % spec[1]
else:
self_arg = None
if compat.py3k:
apply_pos = compat.inspect_formatargspec(
- spec[0], spec[1], spec[2], None, spec[4])
+ spec[0], spec[1], spec[2], None, spec[4]
+ )
num_defaults = 0
if spec[3]:
num_defaults += len(spec[3])
@@ -396,19 +417,31 @@ def format_argspec_plus(fn, grouped=True):
name_args = spec[0]
if num_defaults:
- defaulted_vals = name_args[0 - num_defaults:]
+ defaulted_vals = name_args[0 - num_defaults :]
else:
defaulted_vals = ()
apply_kw = compat.inspect_formatargspec(
- name_args, spec[1], spec[2], defaulted_vals,
- formatvalue=lambda x: '=' + x)
+ name_args,
+ spec[1],
+ spec[2],
+ defaulted_vals,
+ formatvalue=lambda x: "=" + x,
+ )
if grouped:
- return dict(args=args, self_arg=self_arg,
- apply_pos=apply_pos, apply_kw=apply_kw)
+ return dict(
+ args=args,
+ self_arg=self_arg,
+ apply_pos=apply_pos,
+ apply_kw=apply_kw,
+ )
else:
- return dict(args=args[1:-1], self_arg=self_arg,
- apply_pos=apply_pos[1:-1], apply_kw=apply_kw[1:-1])
+ return dict(
+ args=args[1:-1],
+ self_arg=self_arg,
+ apply_pos=apply_pos[1:-1],
+ apply_kw=apply_kw[1:-1],
+ )
def format_argspec_init(method, grouped=True):
@@ -422,14 +455,17 @@ def format_argspec_init(method, grouped=True):
"""
if method is object.__init__:
- args = grouped and '(self)' or 'self'
+ args = grouped and "(self)" or "self"
else:
try:
return format_argspec_plus(method, grouped=grouped)
except TypeError:
- args = (grouped and '(self, *args, **kwargs)'
- or 'self, *args, **kwargs')
- return dict(self_arg='self', args=args, apply_pos=args, apply_kw=args)
+ args = (
+ grouped
+ and "(self, *args, **kwargs)"
+ or "self, *args, **kwargs"
+ )
+ return dict(self_arg="self", args=args, apply_pos=args, apply_kw=args)
def getargspec_init(method):
@@ -445,9 +481,9 @@ def getargspec_init(method):
return compat.inspect_getargspec(method)
except TypeError:
if method is object.__init__:
- return (['self'], None, None, None)
+ return (["self"], None, None, None)
else:
- return (['self'], 'args', 'kwargs', None)
+ return (["self"], "args", "kwargs", None)
def unbound_method_to_callable(func_or_cls):
@@ -479,8 +515,9 @@ def generic_repr(obj, additional_kw=(), to_inspect=None, omit_kwarg=()):
vargs = None
for i, insp in enumerate(to_inspect):
try:
- (_args, _vargs, vkw, defaults) = \
- compat.inspect_getargspec(insp.__init__)
+ (_args, _vargs, vkw, defaults) = compat.inspect_getargspec(
+ insp.__init__
+ )
except TypeError:
continue
else:
@@ -493,16 +530,17 @@ def generic_repr(obj, additional_kw=(), to_inspect=None, omit_kwarg=()):
else:
pos_args.extend(_args[1:])
else:
- kw_args.update([
- (arg, missing) for arg in _args[1:-default_len]
- ])
+ kw_args.update(
+ [(arg, missing) for arg in _args[1:-default_len]]
+ )
if default_len:
- kw_args.update([
- (arg, default)
- for arg, default
- in zip(_args[-default_len:], defaults)
- ])
+ kw_args.update(
+ [
+ (arg, default)
+ for arg, default in zip(_args[-default_len:], defaults)
+ ]
+ )
output = []
output.extend(repr(getattr(obj, arg, None)) for arg in pos_args)
@@ -516,7 +554,7 @@ def generic_repr(obj, additional_kw=(), to_inspect=None, omit_kwarg=()):
try:
val = getattr(obj, arg, missing)
if val is not missing and val != defval:
- output.append('%s=%r' % (arg, val))
+ output.append("%s=%r" % (arg, val))
except Exception:
pass
@@ -525,7 +563,7 @@ def generic_repr(obj, additional_kw=(), to_inspect=None, omit_kwarg=()):
try:
val = getattr(obj, arg, missing)
if val is not missing and val != defval:
- output.append('%s=%r' % (arg, val))
+ output.append("%s=%r" % (arg, val))
except Exception:
pass
@@ -538,16 +576,19 @@ class portable_instancemethod(object):
"""
- __slots__ = 'target', 'name', 'kwargs', '__weakref__'
+ __slots__ = "target", "name", "kwargs", "__weakref__"
def __getstate__(self):
- return {'target': self.target, 'name': self.name,
- 'kwargs': self.kwargs}
+ return {
+ "target": self.target,
+ "name": self.name,
+ "kwargs": self.kwargs,
+ }
def __setstate__(self, state):
- self.target = state['target']
- self.name = state['name']
- self.kwargs = state.get('kwargs', ())
+ self.target = state["target"]
+ self.name = state["name"]
+ self.kwargs = state.get("kwargs", ())
def __init__(self, meth, kwargs=()):
self.target = meth.__self__
@@ -583,8 +624,11 @@ def class_hierarchy(cls):
if compat.py2k:
if isinstance(c, types.ClassType):
continue
- bases = (_ for _ in c.__bases__
- if _ not in hier and not isinstance(_, types.ClassType))
+ bases = (
+ _
+ for _ in c.__bases__
+ if _ not in hier and not isinstance(_, types.ClassType)
+ )
else:
bases = (_ for _ in c.__bases__ if _ not in hier)
@@ -593,11 +637,12 @@ def class_hierarchy(cls):
hier.add(b)
if compat.py3k:
- if c.__module__ == 'builtins' or not hasattr(c, '__subclasses__'):
+ if c.__module__ == "builtins" or not hasattr(c, "__subclasses__"):
continue
else:
- if c.__module__ == '__builtin__' or not hasattr(
- c, '__subclasses__'):
+ if c.__module__ == "__builtin__" or not hasattr(
+ c, "__subclasses__"
+ ):
continue
for s in [_ for _ in c.__subclasses__() if _ not in hier]:
@@ -622,26 +667,45 @@ def iterate_attributes(cls):
break
-def monkeypatch_proxied_specials(into_cls, from_cls, skip=None, only=None,
- name='self.proxy', from_instance=None):
+def monkeypatch_proxied_specials(
+ into_cls,
+ from_cls,
+ skip=None,
+ only=None,
+ name="self.proxy",
+ from_instance=None,
+):
"""Automates delegation of __specials__ for a proxying type."""
if only:
dunders = only
else:
if skip is None:
- skip = ('__slots__', '__del__', '__getattribute__',
- '__metaclass__', '__getstate__', '__setstate__')
- dunders = [m for m in dir(from_cls)
- if (m.startswith('__') and m.endswith('__') and
- not hasattr(into_cls, m) and m not in skip)]
+ skip = (
+ "__slots__",
+ "__del__",
+ "__getattribute__",
+ "__metaclass__",
+ "__getstate__",
+ "__setstate__",
+ )
+ dunders = [
+ m
+ for m in dir(from_cls)
+ if (
+ m.startswith("__")
+ and m.endswith("__")
+ and not hasattr(into_cls, m)
+ and m not in skip
+ )
+ ]
for method in dunders:
try:
fn = getattr(from_cls, method)
- if not hasattr(fn, '__call__'):
+ if not hasattr(fn, "__call__"):
continue
- fn = getattr(fn, 'im_func', fn)
+ fn = getattr(fn, "im_func", fn)
except AttributeError:
continue
try:
@@ -649,11 +713,13 @@ def monkeypatch_proxied_specials(into_cls, from_cls, skip=None, only=None,
fn_args = compat.inspect_formatargspec(spec[0])
d_args = compat.inspect_formatargspec(spec[0][1:])
except TypeError:
- fn_args = '(self, *args, **kw)'
- d_args = '(*args, **kw)'
+ fn_args = "(self, *args, **kw)"
+ d_args = "(*args, **kw)"
- py = ("def %(method)s%(fn_args)s: "
- "return %(name)s.%(method)s%(d_args)s" % locals())
+ py = (
+ "def %(method)s%(fn_args)s: "
+ "return %(name)s.%(method)s%(d_args)s" % locals()
+ )
env = from_instance is not None and {name: from_instance} or {}
compat.exec_(py, env)
@@ -667,8 +733,9 @@ def monkeypatch_proxied_specials(into_cls, from_cls, skip=None, only=None,
def methods_equivalent(meth1, meth2):
"""Return True if the two methods are the same implementation."""
- return getattr(meth1, '__func__', meth1) is getattr(
- meth2, '__func__', meth2)
+ return getattr(meth1, "__func__", meth1) is getattr(
+ meth2, "__func__", meth2
+ )
def as_interface(obj, cls=None, methods=None, required=None):
@@ -705,12 +772,12 @@ def as_interface(obj, cls=None, methods=None, required=None):
"""
if not cls and not methods:
- raise TypeError('a class or collection of method names are required')
+ raise TypeError("a class or collection of method names are required")
if isinstance(cls, type) and isinstance(obj, cls):
return obj
- interface = set(methods or [m for m in dir(cls) if not m.startswith('_')])
+ interface = set(methods or [m for m in dir(cls) if not m.startswith("_")])
implemented = set(dir(obj))
complies = operator.ge
@@ -727,15 +794,17 @@ def as_interface(obj, cls=None, methods=None, required=None):
# No dict duck typing here.
if not isinstance(obj, dict):
- qualifier = complies is operator.gt and 'any of' or 'all of'
- raise TypeError("%r does not implement %s: %s" % (
- obj, qualifier, ', '.join(interface)))
+ qualifier = complies is operator.gt and "any of" or "all of"
+ raise TypeError(
+ "%r does not implement %s: %s"
+ % (obj, qualifier, ", ".join(interface))
+ )
class AnonymousInterface(object):
"""A callable-holding shell."""
if cls:
- AnonymousInterface.__name__ = 'Anonymous' + cls.__name__
+ AnonymousInterface.__name__ = "Anonymous" + cls.__name__
found = set()
for method, impl in dictlike_iteritems(obj):
@@ -749,8 +818,10 @@ def as_interface(obj, cls=None, methods=None, required=None):
if complies(found, required):
return AnonymousInterface
- raise TypeError("dictionary does not contain required keys %s" %
- ', '.join(required - found))
+ raise TypeError(
+ "dictionary does not contain required keys %s"
+ % ", ".join(required - found)
+ )
class memoized_property(object):
@@ -791,6 +862,7 @@ def memoized_instancemethod(fn):
memo.__doc__ = fn.__doc__
self.__dict__[fn.__name__] = memo
return result
+
return update_wrapper(oneshot, fn)
@@ -831,14 +903,14 @@ class MemoizedSlots(object):
raise AttributeError(key)
def __getattr__(self, key):
- if key.startswith('_memoized'):
+ if key.startswith("_memoized"):
raise AttributeError(key)
- elif hasattr(self, '_memoized_attr_%s' % key):
- value = getattr(self, '_memoized_attr_%s' % key)()
+ elif hasattr(self, "_memoized_attr_%s" % key):
+ value = getattr(self, "_memoized_attr_%s" % key)()
setattr(self, key, value)
return value
- elif hasattr(self, '_memoized_method_%s' % key):
- fn = getattr(self, '_memoized_method_%s' % key)
+ elif hasattr(self, "_memoized_method_%s" % key):
+ fn = getattr(self, "_memoized_method_%s" % key)
def oneshot(*args, **kw):
result = fn(*args, **kw)
@@ -847,6 +919,7 @@ class MemoizedSlots(object):
memo.__doc__ = fn.__doc__
setattr(self, key, memo)
return result
+
oneshot.__doc__ = fn.__doc__
return oneshot
else:
@@ -859,12 +932,14 @@ def dependency_for(modulename, add_to_all=False):
# unfortunately importlib doesn't work that great either
tokens = modulename.split(".")
mod = compat.import_(
- ".".join(tokens[0:-1]), globals(), locals(), [tokens[-1]])
+ ".".join(tokens[0:-1]), globals(), locals(), [tokens[-1]]
+ )
mod = getattr(mod, tokens[-1])
setattr(mod, obj.__name__, obj)
if add_to_all and hasattr(mod, "__all__"):
mod.__all__.append(obj.__name__)
return obj
+
return decorate
@@ -891,10 +966,7 @@ class dependencies(object):
for dep in deps:
tokens = dep.split(".")
self.import_deps.append(
- dependencies._importlater(
- ".".join(tokens[0:-1]),
- tokens[-1]
- )
+ dependencies._importlater(".".join(tokens[0:-1]), tokens[-1])
)
def __call__(self, fn):
@@ -902,7 +974,7 @@ class dependencies(object):
spec = compat.inspect_getfullargspec(fn)
spec_zero = list(spec[0])
- hasself = spec_zero[0] in ('self', 'cls')
+ hasself = spec_zero[0] in ("self", "cls")
for i in range(len(import_deps)):
spec[0][i + (1 if hasself else 0)] = "import_deps[%r]" % i
@@ -915,13 +987,13 @@ class dependencies(object):
outer_spec = format_argspec_plus(spec, grouped=False)
- code = 'lambda %(args)s: fn(%(apply_kw)s)' % {
- "args": outer_spec['args'],
- "apply_kw": inner_spec['apply_kw']
+ code = "lambda %(args)s: fn(%(apply_kw)s)" % {
+ "args": outer_spec["args"],
+ "apply_kw": inner_spec["apply_kw"],
}
decorated = eval(code, locals())
- decorated.__defaults__ = getattr(fn, 'im_func', fn).__defaults__
+ decorated.__defaults__ = getattr(fn, "im_func", fn).__defaults__
return update_wrapper(decorated, fn)
@classmethod
@@ -961,26 +1033,27 @@ class dependencies(object):
raise ImportError(
"importlater.resolve_all() hasn't "
"been called (this is %s %s)"
- % (self._il_path, self._il_addtl))
+ % (self._il_path, self._il_addtl)
+ )
return getattr(self._initial_import, self._il_addtl)
def _resolve(self):
dependencies._unresolved.discard(self)
self._initial_import = compat.import_(
- self._il_path, globals(), locals(),
- [self._il_addtl])
+ self._il_path, globals(), locals(), [self._il_addtl]
+ )
def __getattr__(self, key):
- if key == 'module':
- raise ImportError("Could not resolve module %s"
- % self._full_path)
+ if key == "module":
+ raise ImportError(
+ "Could not resolve module %s" % self._full_path
+ )
try:
attr = getattr(self.module, key)
except AttributeError:
raise AttributeError(
- "Module %s has no attribute '%s'" %
- (self._full_path, key)
+ "Module %s has no attribute '%s'" % (self._full_path, key)
)
self.__dict__[key] = attr
return attr
@@ -990,9 +1063,9 @@ class dependencies(object):
def asbool(obj):
if isinstance(obj, compat.string_types):
obj = obj.strip().lower()
- if obj in ['true', 'yes', 'on', 'y', 't', '1']:
+ if obj in ["true", "yes", "on", "y", "t", "1"]:
return True
- elif obj in ['false', 'no', 'off', 'n', 'f', '0']:
+ elif obj in ["false", "no", "off", "n", "f", "0"]:
return False
else:
raise ValueError("String is not true/false: %r" % obj)
@@ -1004,11 +1077,13 @@ def bool_or_str(*text):
boolean, or one of a set of "alternate" string values.
"""
+
def bool_or_value(obj):
if obj in text:
return obj
else:
return asbool(obj)
+
return bool_or_value
@@ -1026,9 +1101,11 @@ def coerce_kw_type(kw, key, type_, flexi_bool=True):
when coercing to boolean.
"""
- if key in kw and (
- not isinstance(type_, type) or not isinstance(kw[key], type_)
- ) and kw[key] is not None:
+ if (
+ key in kw
+ and (not isinstance(type_, type) or not isinstance(kw[key], type_))
+ and kw[key] is not None
+ ):
if type_ is bool and flexi_bool:
kw[key] = asbool(kw[key])
else:
@@ -1044,8 +1121,8 @@ def constructor_copy(obj, cls, *args, **kw):
names = get_cls_kwargs(cls)
kw.update(
- (k, obj.__dict__[k]) for k in names.difference(kw)
- if k in obj.__dict__)
+ (k, obj.__dict__[k]) for k in names.difference(kw) if k in obj.__dict__
+ )
return cls(*args, **kw)
@@ -1072,10 +1149,11 @@ def duck_type_collection(specimen, default=None):
property is present, return that preferentially.
"""
- if hasattr(specimen, '__emulates__'):
+ if hasattr(specimen, "__emulates__"):
# canonicalize set vs sets.Set to a standard: the builtin set
- if (specimen.__emulates__ is not None and
- issubclass(specimen.__emulates__, set)):
+ if specimen.__emulates__ is not None and issubclass(
+ specimen.__emulates__, set
+ ):
return set
else:
return specimen.__emulates__
@@ -1088,11 +1166,11 @@ def duck_type_collection(specimen, default=None):
elif isa(specimen, dict):
return dict
- if hasattr(specimen, 'append'):
+ if hasattr(specimen, "append"):
return list
- elif hasattr(specimen, 'add'):
+ elif hasattr(specimen, "add"):
return set
- elif hasattr(specimen, 'set'):
+ elif hasattr(specimen, "set"):
return dict
else:
return default
@@ -1104,41 +1182,43 @@ def assert_arg_type(arg, argtype, name):
else:
if isinstance(argtype, tuple):
raise exc.ArgumentError(
- "Argument '%s' is expected to be one of type %s, got '%s'" %
- (name, ' or '.join("'%s'" % a for a in argtype), type(arg)))
+ "Argument '%s' is expected to be one of type %s, got '%s'"
+ % (name, " or ".join("'%s'" % a for a in argtype), type(arg))
+ )
else:
raise exc.ArgumentError(
- "Argument '%s' is expected to be of type '%s', got '%s'" %
- (name, argtype, type(arg)))
+ "Argument '%s' is expected to be of type '%s', got '%s'"
+ % (name, argtype, type(arg))
+ )
def dictlike_iteritems(dictlike):
"""Return a (key, value) iterator for almost any dict-like object."""
if compat.py3k:
- if hasattr(dictlike, 'items'):
+ if hasattr(dictlike, "items"):
return list(dictlike.items())
else:
- if hasattr(dictlike, 'iteritems'):
+ if hasattr(dictlike, "iteritems"):
return dictlike.iteritems()
- elif hasattr(dictlike, 'items'):
+ elif hasattr(dictlike, "items"):
return iter(dictlike.items())
- getter = getattr(dictlike, '__getitem__', getattr(dictlike, 'get', None))
+ getter = getattr(dictlike, "__getitem__", getattr(dictlike, "get", None))
if getter is None:
- raise TypeError(
- "Object '%r' is not dict-like" % dictlike)
+ raise TypeError("Object '%r' is not dict-like" % dictlike)
-    if hasattr(dictlike, 'iterkeys'):
+    if hasattr(dictlike, "iterkeys"):
+
def iterator():
for key in dictlike.iterkeys():
yield key, getter(key)
+
return iterator()
- elif hasattr(dictlike, 'keys'):
+ elif hasattr(dictlike, "keys"):
return iter((key, getter(key)) for key in dictlike.keys())
else:
- raise TypeError(
- "Object '%r' is not dict-like" % dictlike)
+ raise TypeError("Object '%r' is not dict-like" % dictlike)
class classproperty(property):
@@ -1207,7 +1287,8 @@ class _symbol(int):
def __repr__(self):
return "symbol(%r)" % self.name
-_symbol.__name__ = 'symbol'
+
+_symbol.__name__ = "symbol"
class symbol(object):
@@ -1231,6 +1312,7 @@ class symbol(object):
``doc`` here.
"""
+
symbols = {}
_lock = compat.threading.Lock()
@@ -1292,9 +1374,11 @@ class _hash_limit_string(compat.text_type):
"""
+
def __new__(cls, value, num, args):
- interpolated = (value % args) + \
- (" (this warning may be suppressed after %d occurrences)" % num)
+ interpolated = (value % args) + (
+ " (this warning may be suppressed after %d occurrences)" % num
+ )
self = super(_hash_limit_string, cls).__new__(cls, interpolated)
self._hash = hash("%s_%d" % (value, hash(interpolated) % num))
return self
@@ -1340,8 +1424,8 @@ def only_once(fn):
return go
-_SQLA_RE = re.compile(r'sqlalchemy/([a-z_]+/){0,2}[a-z_]+\.py')
-_UNITTEST_RE = re.compile(r'unit(?:2|test2?/)')
+_SQLA_RE = re.compile(r"sqlalchemy/([a-z_]+/){0,2}[a-z_]+\.py")
+_UNITTEST_RE = re.compile(r"unit(?:2|test2?/)")
def chop_traceback(tb, exclude_prefix=_UNITTEST_RE, exclude_suffix=_SQLA_RE):
@@ -1363,18 +1447,17 @@ def chop_traceback(tb, exclude_prefix=_UNITTEST_RE, exclude_suffix=_SQLA_RE):
start += 1
while start <= end and exclude_suffix.search(tb[end]):
end -= 1
- return tb[start:end + 1]
+ return tb[start : end + 1]
+
NoneType = type(None)
def attrsetter(attrname):
- code = \
- "def set(obj, value):"\
- " obj.%s = value" % attrname
+ code = "def set(obj, value):" " obj.%s = value" % attrname
env = locals().copy()
exec(code, env)
- return env['set']
+ return env["set"]
class EnsureKWArgType(type):
@@ -1382,6 +1465,7 @@ class EnsureKWArgType(type):
don't already.
"""
+
def __init__(cls, clsname, bases, clsdict):
fn_reg = cls.ensure_kwarg
if fn_reg:
@@ -1396,9 +1480,9 @@ class EnsureKWArgType(type):
super(EnsureKWArgType, cls).__init__(clsname, bases, clsdict)
def _wrap_w_kw(self, fn):
-
def wrap(*arg, **kw):
return fn(*arg)
+
return update_wrapper(wrap, fn)
@@ -1410,15 +1494,15 @@ def wrap_callable(wrapper, fn):
object with __call__ method
"""
- if hasattr(fn, '__name__'):
+ if hasattr(fn, "__name__"):
return update_wrapper(wrapper, fn)
else:
_f = wrapper
_f.__name__ = fn.__class__.__name__
- if hasattr(fn, '__module__'):
+ if hasattr(fn, "__module__"):
_f.__module__ = fn.__module__
- if hasattr(fn.__call__, '__doc__') and fn.__call__.__doc__:
+ if hasattr(fn.__call__, "__doc__") and fn.__call__.__doc__:
_f.__doc__ = fn.__call__.__doc__
elif fn.__doc__:
_f.__doc__ = fn.__doc__
@@ -1468,4 +1552,3 @@ def quoted_token_parser(value):
idx += 1
return ["".join(token) for token in result]
-
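Several helpers reformatted above (memoized_property, memoized_instancemethod, MemoizedSlots) share one mechanism: compute a value once, then plant it on the instance so later lookups bypass the descriptor or __getattr__ entirely. A condensed sketch of the non-data-descriptor form, with a hypothetical Thing class as the consumer:

    class memoized_property(object):
        # Non-data descriptor: no __set__, so the entry written into the
        # instance __dict__ on first access shadows the descriptor afterwards.
        def __init__(self, fget):
            self.fget = fget
            self.__name__ = fget.__name__

        def __get__(self, obj, cls):
            if obj is None:
                return self
            obj.__dict__[self.__name__] = result = self.fget(obj)
            return result

    class Thing(object):
        @memoized_property
        def expensive(self):
            print("computing once")
            return 99

    t = Thing()
    t.expensive  # prints "computing once", stores 99 on t
    t.expensive  # served straight from t.__dict__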
diff --git a/lib/sqlalchemy/util/queue.py b/lib/sqlalchemy/util/queue.py
index 640f70ea9..5e56e855a 100644
--- a/lib/sqlalchemy/util/queue.py
+++ b/lib/sqlalchemy/util/queue.py
@@ -23,7 +23,7 @@ from time import time as _time
from .compat import threading
-__all__ = ['Empty', 'Full', 'Queue']
+__all__ = ["Empty", "Full", "Queue"]
class Empty(Exception):
diff --git a/lib/sqlalchemy/util/topological.py b/lib/sqlalchemy/util/topological.py
index 5f516d67e..95391c31b 100644
--- a/lib/sqlalchemy/util/topological.py
+++ b/lib/sqlalchemy/util/topological.py
@@ -10,7 +10,7 @@
from ..exc import CircularDependencyError
from .. import util
-__all__ = ['sort', 'sort_as_subsets', 'find_cycles']
+__all__ = ["sort", "sort_as_subsets", "find_cycles"]
def sort_as_subsets(tuples, allitems, deterministic_order=False):
@@ -33,7 +33,7 @@ def sort_as_subsets(tuples, allitems, deterministic_order=False):
raise CircularDependencyError(
"Circular dependency detected.",
find_cycles(tuples, allitems),
- _gen_edges(edges)
+ _gen_edges(edges),
)
todo.difference_update(output)
@@ -79,7 +79,7 @@ def find_cycles(tuples, allitems):
top = stack[-1]
for node in edges[top]:
if node in stack:
- cyc = stack[stack.index(node):]
+ cyc = stack[stack.index(node) :]
todo.difference_update(cyc)
output.update(cyc)
@@ -93,8 +93,4 @@ def find_cycles(tuples, allitems):
def _gen_edges(edges):
- return set([
- (right, left)
- for left in edges
- for right in edges[left]
- ])
+ return set([(right, left) for left in edges for right in edges[left]])
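For context on the API being reformatted here: sort_as_subsets() consumes (parent, child) dependency tuples plus the full item collection, and yields successive sets whose members have no unemitted parents. A minimal sketch of that contract (the real function also supports deterministic ordering and reports cycles through find_cycles() and _gen_edges()):

    def sort_as_subsets(tuples, allitems):
        parents = {}  # node -> set of nodes that must come first
        for parent, child in tuples:
            parents.setdefault(child, set()).add(parent)
        todo = set(allitems)
        while todo:
            ready = {n for n in todo if not parents.get(n, set()) & todo}
            if not ready:
                raise ValueError("Circular dependency detected.")
            yield ready
            todo -= ready

    deps = [("a", "b"), ("a", "c"), ("b", "c")]  # invented example
    print(list(sort_as_subsets(deps, ["a", "b", "c"])))
    # [{'a'}, {'b'}, {'c'}]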