diff options
Diffstat (limited to 'lib/sqlalchemy')
102 files changed, 1114 insertions, 1338 deletions
diff --git a/lib/sqlalchemy/connectors/pyodbc.py b/lib/sqlalchemy/connectors/pyodbc.py index 458663bb7..f920860cd 100644 --- a/lib/sqlalchemy/connectors/pyodbc.py +++ b/lib/sqlalchemy/connectors/pyodbc.py @@ -51,7 +51,7 @@ class PyODBCConnector(Connector): dbapi: ModuleType def __init__(self, use_setinputsizes: bool = False, **kw: Any): - super(PyODBCConnector, self).__init__(**kw) + super().__init__(**kw) if use_setinputsizes: self.bind_typing = interfaces.BindTyping.SETINPUTSIZES @@ -83,7 +83,7 @@ class PyODBCConnector(Connector): token = "{%s}" % token.replace("}", "}}") return token - keys = dict((k, check_quote(v)) for k, v in keys.items()) + keys = {k: check_quote(v) for k, v in keys.items()} dsn_connection = "dsn" in keys or ( "host" in keys and "database" not in keys diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index 53fe96c9a..a0049c361 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -965,190 +965,188 @@ MS_2008_VERSION = (10,) MS_2005_VERSION = (9,) MS_2000_VERSION = (8,) -RESERVED_WORDS = set( - [ - "add", - "all", - "alter", - "and", - "any", - "as", - "asc", - "authorization", - "backup", - "begin", - "between", - "break", - "browse", - "bulk", - "by", - "cascade", - "case", - "check", - "checkpoint", - "close", - "clustered", - "coalesce", - "collate", - "column", - "commit", - "compute", - "constraint", - "contains", - "containstable", - "continue", - "convert", - "create", - "cross", - "current", - "current_date", - "current_time", - "current_timestamp", - "current_user", - "cursor", - "database", - "dbcc", - "deallocate", - "declare", - "default", - "delete", - "deny", - "desc", - "disk", - "distinct", - "distributed", - "double", - "drop", - "dump", - "else", - "end", - "errlvl", - "escape", - "except", - "exec", - "execute", - "exists", - "exit", - "external", - "fetch", - "file", - "fillfactor", - "for", - "foreign", - "freetext", - "freetexttable", - "from", - "full", - "function", - "goto", - "grant", - "group", - "having", - "holdlock", - "identity", - "identity_insert", - "identitycol", - "if", - "in", - "index", - "inner", - "insert", - "intersect", - "into", - "is", - "join", - "key", - "kill", - "left", - "like", - "lineno", - "load", - "merge", - "national", - "nocheck", - "nonclustered", - "not", - "null", - "nullif", - "of", - "off", - "offsets", - "on", - "open", - "opendatasource", - "openquery", - "openrowset", - "openxml", - "option", - "or", - "order", - "outer", - "over", - "percent", - "pivot", - "plan", - "precision", - "primary", - "print", - "proc", - "procedure", - "public", - "raiserror", - "read", - "readtext", - "reconfigure", - "references", - "replication", - "restore", - "restrict", - "return", - "revert", - "revoke", - "right", - "rollback", - "rowcount", - "rowguidcol", - "rule", - "save", - "schema", - "securityaudit", - "select", - "session_user", - "set", - "setuser", - "shutdown", - "some", - "statistics", - "system_user", - "table", - "tablesample", - "textsize", - "then", - "to", - "top", - "tran", - "transaction", - "trigger", - "truncate", - "tsequal", - "union", - "unique", - "unpivot", - "update", - "updatetext", - "use", - "user", - "values", - "varying", - "view", - "waitfor", - "when", - "where", - "while", - "with", - "writetext", - ] -) +RESERVED_WORDS = { + "add", + "all", + "alter", + "and", + "any", + "as", + "asc", + "authorization", + "backup", + "begin", + "between", + "break", + "browse", + "bulk", + "by", + "cascade", + "case", + "check", + "checkpoint", + "close", + "clustered", + "coalesce", + "collate", + "column", + "commit", + "compute", + "constraint", + "contains", + "containstable", + "continue", + "convert", + "create", + "cross", + "current", + "current_date", + "current_time", + "current_timestamp", + "current_user", + "cursor", + "database", + "dbcc", + "deallocate", + "declare", + "default", + "delete", + "deny", + "desc", + "disk", + "distinct", + "distributed", + "double", + "drop", + "dump", + "else", + "end", + "errlvl", + "escape", + "except", + "exec", + "execute", + "exists", + "exit", + "external", + "fetch", + "file", + "fillfactor", + "for", + "foreign", + "freetext", + "freetexttable", + "from", + "full", + "function", + "goto", + "grant", + "group", + "having", + "holdlock", + "identity", + "identity_insert", + "identitycol", + "if", + "in", + "index", + "inner", + "insert", + "intersect", + "into", + "is", + "join", + "key", + "kill", + "left", + "like", + "lineno", + "load", + "merge", + "national", + "nocheck", + "nonclustered", + "not", + "null", + "nullif", + "of", + "off", + "offsets", + "on", + "open", + "opendatasource", + "openquery", + "openrowset", + "openxml", + "option", + "or", + "order", + "outer", + "over", + "percent", + "pivot", + "plan", + "precision", + "primary", + "print", + "proc", + "procedure", + "public", + "raiserror", + "read", + "readtext", + "reconfigure", + "references", + "replication", + "restore", + "restrict", + "return", + "revert", + "revoke", + "right", + "rollback", + "rowcount", + "rowguidcol", + "rule", + "save", + "schema", + "securityaudit", + "select", + "session_user", + "set", + "setuser", + "shutdown", + "some", + "statistics", + "system_user", + "table", + "tablesample", + "textsize", + "then", + "to", + "top", + "tran", + "transaction", + "trigger", + "truncate", + "tsequal", + "union", + "unique", + "unpivot", + "update", + "updatetext", + "use", + "user", + "values", + "varying", + "view", + "waitfor", + "when", + "where", + "while", + "with", + "writetext", +} class REAL(sqltypes.REAL): @@ -1159,7 +1157,7 @@ class REAL(sqltypes.REAL): # it is only accepted as the word "REAL" in DDL, the numeric # precision value is not allowed to be present kw.setdefault("precision", 24) - super(REAL, self).__init__(**kw) + super().__init__(**kw) class TINYINT(sqltypes.Integer): @@ -1204,7 +1202,7 @@ class _MSDate(sqltypes.Date): class TIME(sqltypes.TIME): def __init__(self, precision=None, **kwargs): self.precision = precision - super(TIME, self).__init__() + super().__init__() __zero_date = datetime.date(1900, 1, 1) @@ -1273,7 +1271,7 @@ class DATETIME2(_DateTimeBase, sqltypes.DateTime): __visit_name__ = "DATETIME2" def __init__(self, precision=None, **kw): - super(DATETIME2, self).__init__(**kw) + super().__init__(**kw) self.precision = precision @@ -1281,7 +1279,7 @@ class DATETIMEOFFSET(_DateTimeBase, sqltypes.DateTime): __visit_name__ = "DATETIMEOFFSET" def __init__(self, precision=None, **kw): - super(DATETIMEOFFSET, self).__init__(**kw) + super().__init__(**kw) self.precision = precision @@ -1339,7 +1337,7 @@ class TIMESTAMP(sqltypes._Binary): self.convert_int = convert_int def result_processor(self, dialect, coltype): - super_ = super(TIMESTAMP, self).result_processor(dialect, coltype) + super_ = super().result_processor(dialect, coltype) if self.convert_int: def process(value): @@ -1425,7 +1423,7 @@ class VARBINARY(sqltypes.VARBINARY, sqltypes.LargeBinary): raise ValueError( "length must be None or 'max' when setting filestream" ) - super(VARBINARY, self).__init__(length=length) + super().__init__(length=length) class IMAGE(sqltypes.LargeBinary): @@ -1525,12 +1523,12 @@ class UNIQUEIDENTIFIER(sqltypes.Uuid[sqltypes._UUID_RETURN]): @overload def __init__( - self: "UNIQUEIDENTIFIER[_python_UUID]", as_uuid: Literal[True] = ... + self: UNIQUEIDENTIFIER[_python_UUID], as_uuid: Literal[True] = ... ): ... @overload - def __init__(self: "UNIQUEIDENTIFIER[str]", as_uuid: Literal[False] = ...): + def __init__(self: UNIQUEIDENTIFIER[str], as_uuid: Literal[False] = ...): ... def __init__(self, as_uuid: bool = True): @@ -1972,7 +1970,7 @@ class MSExecutionContext(default.DefaultExecutionContext): and column.default.optional ): return None - return super(MSExecutionContext, self).get_insert_default(column) + return super().get_insert_default(column) class MSSQLCompiler(compiler.SQLCompiler): @@ -1990,7 +1988,7 @@ class MSSQLCompiler(compiler.SQLCompiler): def __init__(self, *args, **kwargs): self.tablealiases = {} - super(MSSQLCompiler, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) def _with_legacy_schema_aliasing(fn): def decorate(self, *arg, **kw): @@ -2040,7 +2038,7 @@ class MSSQLCompiler(compiler.SQLCompiler): def get_select_precolumns(self, select, **kw): """MS-SQL puts TOP, it's version of LIMIT here""" - s = super(MSSQLCompiler, self).get_select_precolumns(select, **kw) + s = super().get_select_precolumns(select, **kw) if select._has_row_limiting_clause and self._use_top(select): # ODBC drivers and possibly others @@ -2186,20 +2184,20 @@ class MSSQLCompiler(compiler.SQLCompiler): @_with_legacy_schema_aliasing def visit_table(self, table, mssql_aliased=False, iscrud=False, **kwargs): if mssql_aliased is table or iscrud: - return super(MSSQLCompiler, self).visit_table(table, **kwargs) + return super().visit_table(table, **kwargs) # alias schema-qualified tables alias = self._schema_aliased_table(table) if alias is not None: return self.process(alias, mssql_aliased=table, **kwargs) else: - return super(MSSQLCompiler, self).visit_table(table, **kwargs) + return super().visit_table(table, **kwargs) @_with_legacy_schema_aliasing def visit_alias(self, alias, **kw): # translate for schema-qualified table aliases kw["mssql_aliased"] = alias.element - return super(MSSQLCompiler, self).visit_alias(alias, **kw) + return super().visit_alias(alias, **kw) @_with_legacy_schema_aliasing def visit_column(self, column, add_to_result_map=None, **kw): @@ -2220,9 +2218,9 @@ class MSSQLCompiler(compiler.SQLCompiler): column.type, ) - return super(MSSQLCompiler, self).visit_column(converted, **kw) + return super().visit_column(converted, **kw) - return super(MSSQLCompiler, self).visit_column( + return super().visit_column( column, add_to_result_map=add_to_result_map, **kw ) @@ -2264,7 +2262,7 @@ class MSSQLCompiler(compiler.SQLCompiler): ), **kwargs, ) - return super(MSSQLCompiler, self).visit_binary(binary, **kwargs) + return super().visit_binary(binary, **kwargs) def returning_clause( self, stmt, returning_cols, *, populate_result_map, **kw @@ -2328,9 +2326,7 @@ class MSSQLCompiler(compiler.SQLCompiler): if isinstance(column, expression.Function): return column.label(None) else: - return super(MSSQLCompiler, self).label_select_column( - select, column, asfrom - ) + return super().label_select_column(select, column, asfrom) def for_update_clause(self, select, **kw): # "FOR UPDATE" is only allowed on "DECLARE CURSOR" which @@ -2517,9 +2513,7 @@ class MSSQLStrictCompiler(MSSQLCompiler): # SQL Server wants single quotes around the date string. return "'" + str(value) + "'" else: - return super(MSSQLStrictCompiler, self).render_literal_value( - value, type_ - ) + return super().render_literal_value(value, type_) class MSDDLCompiler(compiler.DDLCompiler): @@ -2704,7 +2698,7 @@ class MSDDLCompiler(compiler.DDLCompiler): schema_name = schema if schema else self.dialect.default_schema_name return ( "execute sp_addextendedproperty 'MS_Description', " - "{0}, 'schema', {1}, 'table', {2}".format( + "{}, 'schema', {}, 'table', {}".format( self.sql_compiler.render_literal_value( create.element.comment, sqltypes.NVARCHAR() ), @@ -2718,7 +2712,7 @@ class MSDDLCompiler(compiler.DDLCompiler): schema_name = schema if schema else self.dialect.default_schema_name return ( "execute sp_dropextendedproperty 'MS_Description', 'schema', " - "{0}, 'table', {1}".format( + "{}, 'table', {}".format( self.preparer.quote_schema(schema_name), self.preparer.format_table(drop.element, use_schema=False), ) @@ -2729,7 +2723,7 @@ class MSDDLCompiler(compiler.DDLCompiler): schema_name = schema if schema else self.dialect.default_schema_name return ( "execute sp_addextendedproperty 'MS_Description', " - "{0}, 'schema', {1}, 'table', {2}, 'column', {3}".format( + "{}, 'schema', {}, 'table', {}, 'column', {}".format( self.sql_compiler.render_literal_value( create.element.comment, sqltypes.NVARCHAR() ), @@ -2746,7 +2740,7 @@ class MSDDLCompiler(compiler.DDLCompiler): schema_name = schema if schema else self.dialect.default_schema_name return ( "execute sp_dropextendedproperty 'MS_Description', 'schema', " - "{0}, 'table', {1}, 'column', {2}".format( + "{}, 'table', {}, 'column', {}".format( self.preparer.quote_schema(schema_name), self.preparer.format_table( drop.element.table, use_schema=False @@ -2760,9 +2754,7 @@ class MSDDLCompiler(compiler.DDLCompiler): if create.element.data_type is not None: data_type = create.element.data_type prefix = " AS %s" % self.type_compiler.process(data_type) - return super(MSDDLCompiler, self).visit_create_sequence( - create, prefix=prefix, **kw - ) + return super().visit_create_sequence(create, prefix=prefix, **kw) def visit_identity_column(self, identity, **kw): text = " IDENTITY" @@ -2777,7 +2769,7 @@ class MSIdentifierPreparer(compiler.IdentifierPreparer): reserved_words = RESERVED_WORDS def __init__(self, dialect): - super(MSIdentifierPreparer, self).__init__( + super().__init__( dialect, initial_quote="[", final_quote="]", @@ -3067,7 +3059,7 @@ class MSDialect(default.DefaultDialect): ) self.legacy_schema_aliasing = legacy_schema_aliasing - super(MSDialect, self).__init__(**opts) + super().__init__(**opts) self._json_serializer = json_serializer self._json_deserializer = json_deserializer @@ -3075,7 +3067,7 @@ class MSDialect(default.DefaultDialect): def do_savepoint(self, connection, name): # give the DBAPI a push connection.exec_driver_sql("IF @@TRANCOUNT = 0 BEGIN TRANSACTION") - super(MSDialect, self).do_savepoint(connection, name) + super().do_savepoint(connection, name) def do_release_savepoint(self, connection, name): # SQL Server does not support RELEASE SAVEPOINT @@ -3083,7 +3075,7 @@ class MSDialect(default.DefaultDialect): def do_rollback(self, dbapi_connection): try: - super(MSDialect, self).do_rollback(dbapi_connection) + super().do_rollback(dbapi_connection) except self.dbapi.ProgrammingError as e: if self.ignore_no_transaction_on_rollback and re.match( r".*\b111214\b", str(e) @@ -3097,15 +3089,13 @@ class MSDialect(default.DefaultDialect): else: raise - _isolation_lookup = set( - [ - "SERIALIZABLE", - "READ UNCOMMITTED", - "READ COMMITTED", - "REPEATABLE READ", - "SNAPSHOT", - ] - ) + _isolation_lookup = { + "SERIALIZABLE", + "READ UNCOMMITTED", + "READ COMMITTED", + "REPEATABLE READ", + "SNAPSHOT", + } def get_isolation_level_values(self, dbapi_connection): return list(self._isolation_lookup) @@ -3134,7 +3124,7 @@ class MSDialect(default.DefaultDialect): "SQL Server version." ) - view_name = "sys.{}".format(row[0]) + view_name = f"sys.{row[0]}" cursor.execute( """ @@ -3164,7 +3154,7 @@ class MSDialect(default.DefaultDialect): cursor.close() def initialize(self, connection): - super(MSDialect, self).initialize(connection) + super().initialize(connection) self._setup_version_attributes() self._setup_supports_nvarchar_max(connection) @@ -3298,7 +3288,7 @@ class MSDialect(default.DefaultDialect): connection.scalar( # U filters on user tables only. text("SELECT object_id(:table_name, 'U')"), - {"table_name": "tempdb.dbo.[{}]".format(tablename)}, + {"table_name": f"tempdb.dbo.[{tablename}]"}, ) ) else: diff --git a/lib/sqlalchemy/dialects/mssql/pymssql.py b/lib/sqlalchemy/dialects/mssql/pymssql.py index 96d03a908..5d859765c 100644 --- a/lib/sqlalchemy/dialects/mssql/pymssql.py +++ b/lib/sqlalchemy/dialects/mssql/pymssql.py @@ -42,7 +42,7 @@ class _MSNumeric_pymssql(sqltypes.Numeric): class MSIdentifierPreparer_pymssql(MSIdentifierPreparer): def __init__(self, dialect): - super(MSIdentifierPreparer_pymssql, self).__init__(dialect) + super().__init__(dialect) # pymssql has the very unusual behavior that it uses pyformat # yet does not require that percent signs be doubled self._double_percents = False @@ -119,9 +119,7 @@ class MSDialect_pymssql(MSDialect): dbapi_connection.autocommit(True) else: dbapi_connection.autocommit(False) - super(MSDialect_pymssql, self).set_isolation_level( - dbapi_connection, level - ) + super().set_isolation_level(dbapi_connection, level) dialect = MSDialect_pymssql diff --git a/lib/sqlalchemy/dialects/mssql/pyodbc.py b/lib/sqlalchemy/dialects/mssql/pyodbc.py index 3b8caef3b..07cbe3a73 100644 --- a/lib/sqlalchemy/dialects/mssql/pyodbc.py +++ b/lib/sqlalchemy/dialects/mssql/pyodbc.py @@ -384,7 +384,7 @@ class _ms_numeric_pyodbc: def bind_processor(self, dialect): - super_process = super(_ms_numeric_pyodbc, self).bind_processor(dialect) + super_process = super().bind_processor(dialect) if not dialect._need_decimal_fix: return super_process @@ -570,7 +570,7 @@ class MSExecutionContext_pyodbc(MSExecutionContext): """ - super(MSExecutionContext_pyodbc, self).pre_exec() + super().pre_exec() # don't embed the scope_identity select into an # "INSERT .. DEFAULT VALUES" @@ -601,7 +601,7 @@ class MSExecutionContext_pyodbc(MSExecutionContext): self._lastrowid = int(row[0]) else: - super(MSExecutionContext_pyodbc, self).post_exec() + super().post_exec() class MSDialect_pyodbc(PyODBCConnector, MSDialect): @@ -648,9 +648,7 @@ class MSDialect_pyodbc(PyODBCConnector, MSDialect): use_setinputsizes=True, **params, ): - super(MSDialect_pyodbc, self).__init__( - use_setinputsizes=use_setinputsizes, **params - ) + super().__init__(use_setinputsizes=use_setinputsizes, **params) self.use_scope_identity = ( self.use_scope_identity and self.dbapi @@ -674,9 +672,7 @@ class MSDialect_pyodbc(PyODBCConnector, MSDialect): # SQL Server docs indicate this function isn't present prior to # 2008. Before we had the VARCHAR cast above, pyodbc would also # fail on this query. - return super(MSDialect_pyodbc, self)._get_server_version_info( - connection - ) + return super()._get_server_version_info(connection) else: version = [] r = re.compile(r"[.\-]") @@ -688,7 +684,7 @@ class MSDialect_pyodbc(PyODBCConnector, MSDialect): return tuple(version) def on_connect(self): - super_ = super(MSDialect_pyodbc, self).on_connect() + super_ = super().on_connect() def on_connect(conn): if super_ is not None: @@ -723,9 +719,7 @@ class MSDialect_pyodbc(PyODBCConnector, MSDialect): def do_executemany(self, cursor, statement, parameters, context=None): if self.fast_executemany: cursor.fast_executemany = True - super(MSDialect_pyodbc, self).do_executemany( - cursor, statement, parameters, context=context - ) + super().do_executemany(cursor, statement, parameters, context=context) def is_disconnect(self, e, connection, cursor): if isinstance(e, self.dbapi.Error): @@ -743,9 +737,7 @@ class MSDialect_pyodbc(PyODBCConnector, MSDialect): "10054", }: return True - return super(MSDialect_pyodbc, self).is_disconnect( - e, connection, cursor - ) + return super().is_disconnect(e, connection, cursor) dialect = MSDialect_pyodbc diff --git a/lib/sqlalchemy/dialects/mysql/aiomysql.py b/lib/sqlalchemy/dialects/mysql/aiomysql.py index 896c90227..79f865cf1 100644 --- a/lib/sqlalchemy/dialects/mysql/aiomysql.py +++ b/lib/sqlalchemy/dialects/mysql/aiomysql.py @@ -294,14 +294,12 @@ class MySQLDialect_aiomysql(MySQLDialect_pymysql): return pool.AsyncAdaptedQueuePool def create_connect_args(self, url): - return super(MySQLDialect_aiomysql, self).create_connect_args( + return super().create_connect_args( url, _translate_args=dict(username="user", database="db") ) def is_disconnect(self, e, connection, cursor): - if super(MySQLDialect_aiomysql, self).is_disconnect( - e, connection, cursor - ): + if super().is_disconnect(e, connection, cursor): return True else: str_e = str(e).lower() diff --git a/lib/sqlalchemy/dialects/mysql/asyncmy.py b/lib/sqlalchemy/dialects/mysql/asyncmy.py index c8f29a2f1..df8965cbb 100644 --- a/lib/sqlalchemy/dialects/mysql/asyncmy.py +++ b/lib/sqlalchemy/dialects/mysql/asyncmy.py @@ -304,14 +304,12 @@ class MySQLDialect_asyncmy(MySQLDialect_pymysql): return pool.AsyncAdaptedQueuePool def create_connect_args(self, url): - return super(MySQLDialect_asyncmy, self).create_connect_args( + return super().create_connect_args( url, _translate_args=dict(username="user", database="db") ) def is_disconnect(self, e, connection, cursor): - if super(MySQLDialect_asyncmy, self).is_disconnect( - e, connection, cursor - ): + if super().is_disconnect(e, connection, cursor): return True else: str_e = str(e).lower() diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index e8ddb6d1e..2525c6c32 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -1354,7 +1354,7 @@ class MySQLCompiler(compiler.SQLCompiler): name_text = self.preparer.quote(column.name) clauses.append("%s = %s" % (name_text, value_text)) - non_matching = set(on_duplicate.update) - set(c.key for c in cols) + non_matching = set(on_duplicate.update) - {c.key for c in cols} if non_matching: util.warn( "Additional column names not matching " @@ -1503,7 +1503,7 @@ class MySQLCompiler(compiler.SQLCompiler): return "CAST(%s AS %s)" % (self.process(cast.clause, **kw), type_) def render_literal_value(self, value, type_): - value = super(MySQLCompiler, self).render_literal_value(value, type_) + value = super().render_literal_value(value, type_) if self.dialect._backslash_escapes: value = value.replace("\\", "\\\\") return value @@ -1534,7 +1534,7 @@ class MySQLCompiler(compiler.SQLCompiler): ) return select._distinct.upper() + " " - return super(MySQLCompiler, self).get_select_precolumns(select, **kw) + return super().get_select_precolumns(select, **kw) def visit_join(self, join, asfrom=False, from_linter=None, **kwargs): if from_linter: @@ -1805,11 +1805,11 @@ class MySQLDDLCompiler(compiler.DDLCompiler): table_opts = [] - opts = dict( - (k[len(self.dialect.name) + 1 :].upper(), v) + opts = { + k[len(self.dialect.name) + 1 :].upper(): v for k, v in table.kwargs.items() if k.startswith("%s_" % self.dialect.name) - ) + } if table.comment is not None: opts["COMMENT"] = table.comment @@ -1963,9 +1963,7 @@ class MySQLDDLCompiler(compiler.DDLCompiler): return text def visit_primary_key_constraint(self, constraint): - text = super(MySQLDDLCompiler, self).visit_primary_key_constraint( - constraint - ) + text = super().visit_primary_key_constraint(constraint) using = constraint.dialect_options["mysql"]["using"] if using: text += " USING %s" % (self.preparer.quote(using)) @@ -2305,7 +2303,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): def visit_enum(self, type_, **kw): if not type_.native_enum: - return super(MySQLTypeCompiler, self).visit_enum(type_) + return super().visit_enum(type_) else: return self._visit_enumerated_values("ENUM", type_, type_.enums) @@ -2351,9 +2349,7 @@ class MySQLIdentifierPreparer(compiler.IdentifierPreparer): else: quote = '"' - super(MySQLIdentifierPreparer, self).__init__( - dialect, initial_quote=quote, escape_quote=quote - ) + super().__init__(dialect, initial_quote=quote, escape_quote=quote) def _quote_free_identifiers(self, *ids): """Unilaterally identifier-quote any number of strings.""" diff --git a/lib/sqlalchemy/dialects/mysql/enumerated.py b/lib/sqlalchemy/dialects/mysql/enumerated.py index 8dc96fb15..350458877 100644 --- a/lib/sqlalchemy/dialects/mysql/enumerated.py +++ b/lib/sqlalchemy/dialects/mysql/enumerated.py @@ -84,7 +84,7 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum, _StringType): if elem == "": return elem else: - return super(ENUM, self)._object_value_for_elem(elem) + return super()._object_value_for_elem(elem) def __repr__(self): return util.generic_repr( @@ -153,15 +153,15 @@ class SET(_StringType): "setting retrieve_as_bitwise=True" ) if self.retrieve_as_bitwise: - self._bitmap = dict( - (value, 2**idx) for idx, value in enumerate(self.values) - ) + self._bitmap = { + value: 2**idx for idx, value in enumerate(self.values) + } self._bitmap.update( (2**idx, value) for idx, value in enumerate(self.values) ) length = max([len(v) for v in values] + [0]) kw.setdefault("length", length) - super(SET, self).__init__(**kw) + super().__init__(**kw) def column_expression(self, colexpr): if self.retrieve_as_bitwise: @@ -183,7 +183,7 @@ class SET(_StringType): return None else: - super_convert = super(SET, self).result_processor(dialect, coltype) + super_convert = super().result_processor(dialect, coltype) def process(value): if isinstance(value, str): @@ -201,7 +201,7 @@ class SET(_StringType): return process def bind_processor(self, dialect): - super_convert = super(SET, self).bind_processor(dialect) + super_convert = super().bind_processor(dialect) if self.retrieve_as_bitwise: def process(value): diff --git a/lib/sqlalchemy/dialects/mysql/expression.py b/lib/sqlalchemy/dialects/mysql/expression.py index c8c693517..561803a78 100644 --- a/lib/sqlalchemy/dialects/mysql/expression.py +++ b/lib/sqlalchemy/dialects/mysql/expression.py @@ -107,9 +107,7 @@ class match(Generative, elements.BinaryExpression): if kw: raise exc.ArgumentError("unknown arguments: %s" % (", ".join(kw))) - super(match, self).__init__( - left, against, operators.match_op, modifiers=flags - ) + super().__init__(left, against, operators.match_op, modifiers=flags) @_generative def in_boolean_mode(self: Selfmatch) -> Selfmatch: diff --git a/lib/sqlalchemy/dialects/mysql/mariadbconnector.py b/lib/sqlalchemy/dialects/mysql/mariadbconnector.py index 6327d8687..a3f288ceb 100644 --- a/lib/sqlalchemy/dialects/mysql/mariadbconnector.py +++ b/lib/sqlalchemy/dialects/mysql/mariadbconnector.py @@ -102,7 +102,7 @@ class MySQLDialect_mariadbconnector(MySQLDialect): return (99, 99, 99) def __init__(self, **kwargs): - super(MySQLDialect_mariadbconnector, self).__init__(**kwargs) + super().__init__(**kwargs) self.paramstyle = "qmark" if self.dbapi is not None: if self._dbapi_version < mariadb_cpy_minimum_version: @@ -117,9 +117,7 @@ class MySQLDialect_mariadbconnector(MySQLDialect): return __import__("mariadb") def is_disconnect(self, e, connection, cursor): - if super(MySQLDialect_mariadbconnector, self).is_disconnect( - e, connection, cursor - ): + if super().is_disconnect(e, connection, cursor): return True elif isinstance(e, self.dbapi.Error): str_e = str(e).lower() @@ -188,9 +186,7 @@ class MySQLDialect_mariadbconnector(MySQLDialect): connection.autocommit = True else: connection.autocommit = False - super(MySQLDialect_mariadbconnector, self).set_isolation_level( - connection, level - ) + super().set_isolation_level(connection, level) def do_begin_twophase(self, connection, xid): connection.execute( diff --git a/lib/sqlalchemy/dialects/mysql/mysqlconnector.py b/lib/sqlalchemy/dialects/mysql/mysqlconnector.py index 58e92c4ab..f29a5008c 100644 --- a/lib/sqlalchemy/dialects/mysql/mysqlconnector.py +++ b/lib/sqlalchemy/dialects/mysql/mysqlconnector.py @@ -167,24 +167,20 @@ class MySQLDialect_mysqlconnector(MySQLDialect): def _compat_fetchone(self, rp, charset=None): return rp.fetchone() - _isolation_lookup = set( - [ - "SERIALIZABLE", - "READ UNCOMMITTED", - "READ COMMITTED", - "REPEATABLE READ", - "AUTOCOMMIT", - ] - ) + _isolation_lookup = { + "SERIALIZABLE", + "READ UNCOMMITTED", + "READ COMMITTED", + "REPEATABLE READ", + "AUTOCOMMIT", + } def _set_isolation_level(self, connection, level): if level == "AUTOCOMMIT": connection.autocommit = True else: connection.autocommit = False - super(MySQLDialect_mysqlconnector, self)._set_isolation_level( - connection, level - ) + super()._set_isolation_level(connection, level) dialect = MySQLDialect_mysqlconnector diff --git a/lib/sqlalchemy/dialects/mysql/mysqldb.py b/lib/sqlalchemy/dialects/mysql/mysqldb.py index 60b9cb103..9eb1ef84a 100644 --- a/lib/sqlalchemy/dialects/mysql/mysqldb.py +++ b/lib/sqlalchemy/dialects/mysql/mysqldb.py @@ -137,7 +137,7 @@ class MySQLDialect_mysqldb(MySQLDialect): preparer = MySQLIdentifierPreparer def __init__(self, **kwargs): - super(MySQLDialect_mysqldb, self).__init__(**kwargs) + super().__init__(**kwargs) self._mysql_dbapi_version = ( self._parse_dbapi_version(self.dbapi.__version__) if self.dbapi is not None and hasattr(self.dbapi, "__version__") @@ -165,7 +165,7 @@ class MySQLDialect_mysqldb(MySQLDialect): return __import__("MySQLdb") def on_connect(self): - super_ = super(MySQLDialect_mysqldb, self).on_connect() + super_ = super().on_connect() def on_connect(conn): if super_ is not None: @@ -221,9 +221,7 @@ class MySQLDialect_mysqldb(MySQLDialect): ] else: additional_tests = [] - return super(MySQLDialect_mysqldb, self)._check_unicode_returns( - connection, additional_tests - ) + return super()._check_unicode_returns(connection, additional_tests) def create_connect_args(self, url, _translate_args=None): if _translate_args is None: @@ -324,9 +322,7 @@ class MySQLDialect_mysqldb(MySQLDialect): dbapi_connection.autocommit(True) else: dbapi_connection.autocommit(False) - super(MySQLDialect_mysqldb, self).set_isolation_level( - dbapi_connection, level - ) + super().set_isolation_level(dbapi_connection, level) dialect = MySQLDialect_mysqldb diff --git a/lib/sqlalchemy/dialects/mysql/pymysql.py b/lib/sqlalchemy/dialects/mysql/pymysql.py index 66d2f3242..8a194d7fb 100644 --- a/lib/sqlalchemy/dialects/mysql/pymysql.py +++ b/lib/sqlalchemy/dialects/mysql/pymysql.py @@ -65,14 +65,12 @@ class MySQLDialect_pymysql(MySQLDialect_mysqldb): def create_connect_args(self, url, _translate_args=None): if _translate_args is None: _translate_args = dict(username="user") - return super(MySQLDialect_pymysql, self).create_connect_args( + return super().create_connect_args( url, _translate_args=_translate_args ) def is_disconnect(self, e, connection, cursor): - if super(MySQLDialect_pymysql, self).is_disconnect( - e, connection, cursor - ): + if super().is_disconnect(e, connection, cursor): return True elif isinstance(e, self.dbapi.Error): str_e = str(e).lower() diff --git a/lib/sqlalchemy/dialects/mysql/pyodbc.py b/lib/sqlalchemy/dialects/mysql/pyodbc.py index 2d31dfe5f..f9464f39f 100644 --- a/lib/sqlalchemy/dialects/mysql/pyodbc.py +++ b/lib/sqlalchemy/dialects/mysql/pyodbc.py @@ -118,7 +118,7 @@ class MySQLDialect_pyodbc(PyODBCConnector, MySQLDialect): return None def on_connect(self): - super_ = super(MySQLDialect_pyodbc, self).on_connect() + super_ = super().on_connect() def on_connect(conn): if super_ is not None: diff --git a/lib/sqlalchemy/dialects/mysql/reflection.py b/lib/sqlalchemy/dialects/mysql/reflection.py index 44bc62179..fa1b7e0b7 100644 --- a/lib/sqlalchemy/dialects/mysql/reflection.py +++ b/lib/sqlalchemy/dialects/mysql/reflection.py @@ -340,9 +340,9 @@ class MySQLTableDefinitionParser: buffer = [] for row in columns: - (name, col_type, nullable, default, extra) = [ + (name, col_type, nullable, default, extra) = ( row[i] for i in (0, 1, 2, 4, 5) - ] + ) line = [" "] line.append(self.preparer.quote_identifier(name)) diff --git a/lib/sqlalchemy/dialects/mysql/types.py b/lib/sqlalchemy/dialects/mysql/types.py index a74fba177..5a96b890b 100644 --- a/lib/sqlalchemy/dialects/mysql/types.py +++ b/lib/sqlalchemy/dialects/mysql/types.py @@ -25,7 +25,7 @@ class _NumericType: def __init__(self, unsigned=False, zerofill=False, **kw): self.unsigned = unsigned self.zerofill = zerofill - super(_NumericType, self).__init__(**kw) + super().__init__(**kw) def __repr__(self): return util.generic_repr( @@ -43,9 +43,7 @@ class _FloatType(_NumericType, sqltypes.Float): "You must specify both precision and scale or omit " "both altogether." ) - super(_FloatType, self).__init__( - precision=precision, asdecimal=asdecimal, **kw - ) + super().__init__(precision=precision, asdecimal=asdecimal, **kw) self.scale = scale def __repr__(self): @@ -57,7 +55,7 @@ class _FloatType(_NumericType, sqltypes.Float): class _IntegerType(_NumericType, sqltypes.Integer): def __init__(self, display_width=None, **kw): self.display_width = display_width - super(_IntegerType, self).__init__(**kw) + super().__init__(**kw) def __repr__(self): return util.generic_repr( @@ -87,7 +85,7 @@ class _StringType(sqltypes.String): self.unicode = unicode self.binary = binary self.national = national - super(_StringType, self).__init__(**kw) + super().__init__(**kw) def __repr__(self): return util.generic_repr( @@ -123,7 +121,7 @@ class NUMERIC(_NumericType, sqltypes.NUMERIC): numeric. """ - super(NUMERIC, self).__init__( + super().__init__( precision=precision, scale=scale, asdecimal=asdecimal, **kw ) @@ -149,7 +147,7 @@ class DECIMAL(_NumericType, sqltypes.DECIMAL): numeric. """ - super(DECIMAL, self).__init__( + super().__init__( precision=precision, scale=scale, asdecimal=asdecimal, **kw ) @@ -183,7 +181,7 @@ class DOUBLE(_FloatType, sqltypes.DOUBLE): numeric. """ - super(DOUBLE, self).__init__( + super().__init__( precision=precision, scale=scale, asdecimal=asdecimal, **kw ) @@ -217,7 +215,7 @@ class REAL(_FloatType, sqltypes.REAL): numeric. """ - super(REAL, self).__init__( + super().__init__( precision=precision, scale=scale, asdecimal=asdecimal, **kw ) @@ -243,7 +241,7 @@ class FLOAT(_FloatType, sqltypes.FLOAT): numeric. """ - super(FLOAT, self).__init__( + super().__init__( precision=precision, scale=scale, asdecimal=asdecimal, **kw ) @@ -269,7 +267,7 @@ class INTEGER(_IntegerType, sqltypes.INTEGER): numeric. """ - super(INTEGER, self).__init__(display_width=display_width, **kw) + super().__init__(display_width=display_width, **kw) class BIGINT(_IntegerType, sqltypes.BIGINT): @@ -290,7 +288,7 @@ class BIGINT(_IntegerType, sqltypes.BIGINT): numeric. """ - super(BIGINT, self).__init__(display_width=display_width, **kw) + super().__init__(display_width=display_width, **kw) class MEDIUMINT(_IntegerType): @@ -311,7 +309,7 @@ class MEDIUMINT(_IntegerType): numeric. """ - super(MEDIUMINT, self).__init__(display_width=display_width, **kw) + super().__init__(display_width=display_width, **kw) class TINYINT(_IntegerType): @@ -332,7 +330,7 @@ class TINYINT(_IntegerType): numeric. """ - super(TINYINT, self).__init__(display_width=display_width, **kw) + super().__init__(display_width=display_width, **kw) class SMALLINT(_IntegerType, sqltypes.SMALLINT): @@ -353,7 +351,7 @@ class SMALLINT(_IntegerType, sqltypes.SMALLINT): numeric. """ - super(SMALLINT, self).__init__(display_width=display_width, **kw) + super().__init__(display_width=display_width, **kw) class BIT(sqltypes.TypeEngine): @@ -417,7 +415,7 @@ class TIME(sqltypes.TIME): MySQL Connector/Python. """ - super(TIME, self).__init__(timezone=timezone) + super().__init__(timezone=timezone) self.fsp = fsp def result_processor(self, dialect, coltype): @@ -462,7 +460,7 @@ class TIMESTAMP(sqltypes.TIMESTAMP): MySQL Connector/Python. """ - super(TIMESTAMP, self).__init__(timezone=timezone) + super().__init__(timezone=timezone) self.fsp = fsp @@ -487,7 +485,7 @@ class DATETIME(sqltypes.DATETIME): MySQL Connector/Python. """ - super(DATETIME, self).__init__(timezone=timezone) + super().__init__(timezone=timezone) self.fsp = fsp @@ -533,7 +531,7 @@ class TEXT(_StringType, sqltypes.TEXT): only the collation of character data. """ - super(TEXT, self).__init__(length=length, **kw) + super().__init__(length=length, **kw) class TINYTEXT(_StringType): @@ -565,7 +563,7 @@ class TINYTEXT(_StringType): only the collation of character data. """ - super(TINYTEXT, self).__init__(**kwargs) + super().__init__(**kwargs) class MEDIUMTEXT(_StringType): @@ -597,7 +595,7 @@ class MEDIUMTEXT(_StringType): only the collation of character data. """ - super(MEDIUMTEXT, self).__init__(**kwargs) + super().__init__(**kwargs) class LONGTEXT(_StringType): @@ -629,7 +627,7 @@ class LONGTEXT(_StringType): only the collation of character data. """ - super(LONGTEXT, self).__init__(**kwargs) + super().__init__(**kwargs) class VARCHAR(_StringType, sqltypes.VARCHAR): @@ -661,7 +659,7 @@ class VARCHAR(_StringType, sqltypes.VARCHAR): only the collation of character data. """ - super(VARCHAR, self).__init__(length=length, **kwargs) + super().__init__(length=length, **kwargs) class CHAR(_StringType, sqltypes.CHAR): @@ -682,7 +680,7 @@ class CHAR(_StringType, sqltypes.CHAR): compatible with the national character set. """ - super(CHAR, self).__init__(length=length, **kwargs) + super().__init__(length=length, **kwargs) @classmethod def _adapt_string_for_cast(self, type_): @@ -728,7 +726,7 @@ class NVARCHAR(_StringType, sqltypes.NVARCHAR): """ kwargs["national"] = True - super(NVARCHAR, self).__init__(length=length, **kwargs) + super().__init__(length=length, **kwargs) class NCHAR(_StringType, sqltypes.NCHAR): @@ -754,7 +752,7 @@ class NCHAR(_StringType, sqltypes.NCHAR): """ kwargs["national"] = True - super(NCHAR, self).__init__(length=length, **kwargs) + super().__init__(length=length, **kwargs) class TINYBLOB(sqltypes._Binary): diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index 6481ae483..0d51bf73d 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -813,7 +813,7 @@ class OracleCompiler(compiler.SQLCompiler): def __init__(self, *args, **kwargs): self.__wheres = {} - super(OracleCompiler, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) def visit_mod_binary(self, binary, operator, **kw): return "mod(%s, %s)" % ( @@ -852,15 +852,13 @@ class OracleCompiler(compiler.SQLCompiler): return "" def visit_function(self, func, **kw): - text = super(OracleCompiler, self).visit_function(func, **kw) + text = super().visit_function(func, **kw) if kw.get("asfrom", False): text = "TABLE (%s)" % func return text def visit_table_valued_column(self, element, **kw): - text = super(OracleCompiler, self).visit_table_valued_column( - element, **kw - ) + text = super().visit_table_valued_column(element, **kw) text = "COLUMN_VALUE " + text return text @@ -1331,9 +1329,7 @@ class OracleDDLCompiler(compiler.DDLCompiler): return "".join(table_opts) def get_identity_options(self, identity_options): - text = super(OracleDDLCompiler, self).get_identity_options( - identity_options - ) + text = super().get_identity_options(identity_options) text = text.replace("NO MINVALUE", "NOMINVALUE") text = text.replace("NO MAXVALUE", "NOMAXVALUE") text = text.replace("NO CYCLE", "NOCYCLE") @@ -1386,9 +1382,7 @@ class OracleIdentifierPreparer(compiler.IdentifierPreparer): def format_savepoint(self, savepoint): name = savepoint.ident.lstrip("_") - return super(OracleIdentifierPreparer, self).format_savepoint( - savepoint, name - ) + return super().format_savepoint(savepoint, name) class OracleExecutionContext(default.DefaultExecutionContext): @@ -1489,7 +1483,7 @@ class OracleDialect(default.DefaultDialect): ) = enable_offset_fetch def initialize(self, connection): - super(OracleDialect, self).initialize(connection) + super().initialize(connection) # Oracle 8i has RETURNING: # https://docs.oracle.com/cd/A87860_01/doc/index.htm diff --git a/lib/sqlalchemy/dialects/oracle/cx_oracle.py b/lib/sqlalchemy/dialects/oracle/cx_oracle.py index 5a0c0e160..0be309cd4 100644 --- a/lib/sqlalchemy/dialects/oracle/cx_oracle.py +++ b/lib/sqlalchemy/dialects/oracle/cx_oracle.py @@ -672,9 +672,7 @@ class _OracleBinary(_LOBDataType, sqltypes.LargeBinary): if not dialect.auto_convert_lobs: return None else: - return super(_OracleBinary, self).result_processor( - dialect, coltype - ) + return super().result_processor(dialect, coltype) class _OracleInterval(oracle.INTERVAL): diff --git a/lib/sqlalchemy/dialects/oracle/types.py b/lib/sqlalchemy/dialects/oracle/types.py index 60a8ebcb5..5cea62b9f 100644 --- a/lib/sqlalchemy/dialects/oracle/types.py +++ b/lib/sqlalchemy/dialects/oracle/types.py @@ -35,12 +35,10 @@ class NUMBER(sqltypes.Numeric, sqltypes.Integer): if asdecimal is None: asdecimal = bool(scale and scale > 0) - super(NUMBER, self).__init__( - precision=precision, scale=scale, asdecimal=asdecimal - ) + super().__init__(precision=precision, scale=scale, asdecimal=asdecimal) def adapt(self, impltype): - ret = super(NUMBER, self).adapt(impltype) + ret = super().adapt(impltype) # leave a hint for the DBAPI handler ret._is_oracle_number = True return ret diff --git a/lib/sqlalchemy/dialects/postgresql/_psycopg_common.py b/lib/sqlalchemy/dialects/postgresql/_psycopg_common.py index 92341d2da..4bb1026a5 100644 --- a/lib/sqlalchemy/dialects/postgresql/_psycopg_common.py +++ b/lib/sqlalchemy/dialects/postgresql/_psycopg_common.py @@ -60,15 +60,13 @@ class _PsycopgHStore(HSTORE): if dialect._has_native_hstore: return None else: - return super(_PsycopgHStore, self).bind_processor(dialect) + return super().bind_processor(dialect) def result_processor(self, dialect, coltype): if dialect._has_native_hstore: return None else: - return super(_PsycopgHStore, self).result_processor( - dialect, coltype - ) + return super().result_processor(dialect, coltype) class _PsycopgARRAY(PGARRAY): diff --git a/lib/sqlalchemy/dialects/postgresql/array.py b/lib/sqlalchemy/dialects/postgresql/array.py index 3132e875e..e130eccc2 100644 --- a/lib/sqlalchemy/dialects/postgresql/array.py +++ b/lib/sqlalchemy/dialects/postgresql/array.py @@ -101,7 +101,7 @@ class array(expression.ExpressionClauseList[_T]): def __init__(self, clauses, **kw): type_arg = kw.pop("type_", None) - super(array, self).__init__(operators.comma_op, *clauses, **kw) + super().__init__(operators.comma_op, *clauses, **kw) self._type_tuple = [arg.type for arg in self.clauses] diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py index cd161d28e..751dc3dcf 100644 --- a/lib/sqlalchemy/dialects/postgresql/asyncpg.py +++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py @@ -560,7 +560,7 @@ class AsyncAdapt_asyncpg_ss_cursor(AsyncAdapt_asyncpg_cursor): __slots__ = ("_rowbuffer",) def __init__(self, adapt_connection): - super(AsyncAdapt_asyncpg_ss_cursor, self).__init__(adapt_connection) + super().__init__(adapt_connection) self._rowbuffer = None def close(self): @@ -863,9 +863,7 @@ class AsyncAdapt_asyncpg_dbapi: class InvalidCachedStatementError(NotSupportedError): def __init__(self, message): - super( - AsyncAdapt_asyncpg_dbapi.InvalidCachedStatementError, self - ).__init__( + super().__init__( message + " (SQLAlchemy asyncpg dialect will now invalidate " "all prepared caches in response to this exception)", ) @@ -1095,7 +1093,7 @@ class PGDialect_asyncpg(PGDialect): """ - super_connect = super(PGDialect_asyncpg, self).on_connect() + super_connect = super().on_connect() def connect(conn): conn.await_(self.setup_asyncpg_json_codec(conn)) diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index a908ed6b7..49ee89daa 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -1469,112 +1469,110 @@ from ...util.typing import TypedDict IDX_USING = re.compile(r"^(?:btree|hash|gist|gin|[\w_]+)$", re.I) -RESERVED_WORDS = set( - [ - "all", - "analyse", - "analyze", - "and", - "any", - "array", - "as", - "asc", - "asymmetric", - "both", - "case", - "cast", - "check", - "collate", - "column", - "constraint", - "create", - "current_catalog", - "current_date", - "current_role", - "current_time", - "current_timestamp", - "current_user", - "default", - "deferrable", - "desc", - "distinct", - "do", - "else", - "end", - "except", - "false", - "fetch", - "for", - "foreign", - "from", - "grant", - "group", - "having", - "in", - "initially", - "intersect", - "into", - "leading", - "limit", - "localtime", - "localtimestamp", - "new", - "not", - "null", - "of", - "off", - "offset", - "old", - "on", - "only", - "or", - "order", - "placing", - "primary", - "references", - "returning", - "select", - "session_user", - "some", - "symmetric", - "table", - "then", - "to", - "trailing", - "true", - "union", - "unique", - "user", - "using", - "variadic", - "when", - "where", - "window", - "with", - "authorization", - "between", - "binary", - "cross", - "current_schema", - "freeze", - "full", - "ilike", - "inner", - "is", - "isnull", - "join", - "left", - "like", - "natural", - "notnull", - "outer", - "over", - "overlaps", - "right", - "similar", - "verbose", - ] -) +RESERVED_WORDS = { + "all", + "analyse", + "analyze", + "and", + "any", + "array", + "as", + "asc", + "asymmetric", + "both", + "case", + "cast", + "check", + "collate", + "column", + "constraint", + "create", + "current_catalog", + "current_date", + "current_role", + "current_time", + "current_timestamp", + "current_user", + "default", + "deferrable", + "desc", + "distinct", + "do", + "else", + "end", + "except", + "false", + "fetch", + "for", + "foreign", + "from", + "grant", + "group", + "having", + "in", + "initially", + "intersect", + "into", + "leading", + "limit", + "localtime", + "localtimestamp", + "new", + "not", + "null", + "of", + "off", + "offset", + "old", + "on", + "only", + "or", + "order", + "placing", + "primary", + "references", + "returning", + "select", + "session_user", + "some", + "symmetric", + "table", + "then", + "to", + "trailing", + "true", + "union", + "unique", + "user", + "using", + "variadic", + "when", + "where", + "window", + "with", + "authorization", + "between", + "binary", + "cross", + "current_schema", + "freeze", + "full", + "ilike", + "inner", + "is", + "isnull", + "join", + "left", + "like", + "natural", + "notnull", + "outer", + "over", + "overlaps", + "right", + "similar", + "verbose", +} colspecs = { sqltypes.ARRAY: _array.ARRAY, @@ -1801,7 +1799,7 @@ class PGCompiler(compiler.SQLCompiler): ) def render_literal_value(self, value, type_): - value = super(PGCompiler, self).render_literal_value(value, type_) + value = super().render_literal_value(value, type_) if self.dialect._backslash_escapes: value = value.replace("\\", "\\\\") @@ -2108,14 +2106,12 @@ class PGDDLCompiler(compiler.DDLCompiler): "create_constraint=False on this Enum datatype." ) - text = super(PGDDLCompiler, self).visit_check_constraint(constraint) + text = super().visit_check_constraint(constraint) text += self._define_constraint_validity(constraint) return text def visit_foreign_key_constraint(self, constraint): - text = super(PGDDLCompiler, self).visit_foreign_key_constraint( - constraint - ) + text = super().visit_foreign_key_constraint(constraint) text += self._define_constraint_validity(constraint) return text @@ -2353,9 +2349,7 @@ class PGDDLCompiler(compiler.DDLCompiler): create.element.data_type ) - return super(PGDDLCompiler, self).visit_create_sequence( - create, prefix=prefix, **kw - ) + return super().visit_create_sequence(create, prefix=prefix, **kw) def _can_comment_on_constraint(self, ddl_instance): constraint = ddl_instance.element @@ -2478,7 +2472,7 @@ class PGTypeCompiler(compiler.GenericTypeCompiler): def visit_enum(self, type_, **kw): if not type_.native_enum or not self.dialect.supports_native_enum: - return super(PGTypeCompiler, self).visit_enum(type_, **kw) + return super().visit_enum(type_, **kw) else: return self.visit_ENUM(type_, **kw) @@ -2803,7 +2797,7 @@ class PGExecutionContext(default.DefaultExecutionContext): return self._execute_scalar(exc, column.type) - return super(PGExecutionContext, self).get_insert_default(column) + return super().get_insert_default(column) class PGReadOnlyConnectionCharacteristic( @@ -2945,7 +2939,7 @@ class PGDialect(default.DefaultDialect): self._json_serializer = json_serializer def initialize(self, connection): - super(PGDialect, self).initialize(connection) + super().initialize(connection) # https://www.postgresql.org/docs/9.3/static/release-9-2.html#AEN116689 self.supports_smallserial = self.server_version_info >= (9, 2) diff --git a/lib/sqlalchemy/dialects/postgresql/dml.py b/lib/sqlalchemy/dialects/postgresql/dml.py index b79b4a30e..645bedf17 100644 --- a/lib/sqlalchemy/dialects/postgresql/dml.py +++ b/lib/sqlalchemy/dialects/postgresql/dml.py @@ -266,7 +266,7 @@ class OnConflictDoUpdate(OnConflictClause): set_=None, where=None, ): - super(OnConflictDoUpdate, self).__init__( + super().__init__( constraint=constraint, index_elements=index_elements, index_where=index_where, diff --git a/lib/sqlalchemy/dialects/postgresql/ext.py b/lib/sqlalchemy/dialects/postgresql/ext.py index ebaad2734..b0d8ef345 100644 --- a/lib/sqlalchemy/dialects/postgresql/ext.py +++ b/lib/sqlalchemy/dialects/postgresql/ext.py @@ -243,7 +243,7 @@ class ExcludeConstraint(ColumnCollectionConstraint): self.ops = kw.get("ops", {}) def _set_parent(self, table, **kw): - super(ExcludeConstraint, self)._set_parent(table) + super()._set_parent(table) self._render_exprs = [ ( diff --git a/lib/sqlalchemy/dialects/postgresql/json.py b/lib/sqlalchemy/dialects/postgresql/json.py index a8b03bd48..c68671918 100644 --- a/lib/sqlalchemy/dialects/postgresql/json.py +++ b/lib/sqlalchemy/dialects/postgresql/json.py @@ -221,7 +221,7 @@ class JSON(sqltypes.JSON): .. versionadded:: 1.1 """ - super(JSON, self).__init__(none_as_null=none_as_null) + super().__init__(none_as_null=none_as_null) if astext_type is not None: self.astext_type = astext_type diff --git a/lib/sqlalchemy/dialects/postgresql/named_types.py b/lib/sqlalchemy/dialects/postgresql/named_types.py index f844e9213..79b567f08 100644 --- a/lib/sqlalchemy/dialects/postgresql/named_types.py +++ b/lib/sqlalchemy/dialects/postgresql/named_types.py @@ -31,8 +31,8 @@ class NamedType(sqltypes.TypeEngine): """Base for named types.""" __abstract__ = True - DDLGenerator: Type["NamedTypeGenerator"] - DDLDropper: Type["NamedTypeDropper"] + DDLGenerator: Type[NamedTypeGenerator] + DDLDropper: Type[NamedTypeDropper] create_type: bool def create(self, bind, checkfirst=True, **kw): diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2.py b/lib/sqlalchemy/dialects/postgresql/psycopg2.py index cb5cab178..5acd50710 100644 --- a/lib/sqlalchemy/dialects/postgresql/psycopg2.py +++ b/lib/sqlalchemy/dialects/postgresql/psycopg2.py @@ -651,7 +651,7 @@ class PGDialect_psycopg2(_PGDialect_common_psycopg): ) def initialize(self, connection): - super(PGDialect_psycopg2, self).initialize(connection) + super().initialize(connection) self._has_native_hstore = ( self.use_native_hstore and self._hstore_oids(connection.connection.dbapi_connection) diff --git a/lib/sqlalchemy/dialects/postgresql/types.py b/lib/sqlalchemy/dialects/postgresql/types.py index 81b677187..72703ff81 100644 --- a/lib/sqlalchemy/dialects/postgresql/types.py +++ b/lib/sqlalchemy/dialects/postgresql/types.py @@ -128,7 +128,7 @@ class TIMESTAMP(sqltypes.TIMESTAMP): .. versionadded:: 1.4 """ - super(TIMESTAMP, self).__init__(timezone=timezone) + super().__init__(timezone=timezone) self.precision = precision @@ -147,7 +147,7 @@ class TIME(sqltypes.TIME): .. versionadded:: 1.4 """ - super(TIME, self).__init__(timezone=timezone) + super().__init__(timezone=timezone) self.precision = precision diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index c3cb10cef..4e5808f62 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -921,9 +921,7 @@ from ...types import VARCHAR # noqa class _SQliteJson(JSON): def result_processor(self, dialect, coltype): - default_processor = super(_SQliteJson, self).result_processor( - dialect, coltype - ) + default_processor = super().result_processor(dialect, coltype) def process(value): try: @@ -942,7 +940,7 @@ class _DateTimeMixin: _storage_format = None def __init__(self, storage_format=None, regexp=None, **kw): - super(_DateTimeMixin, self).__init__(**kw) + super().__init__(**kw) if regexp is not None: self._reg = re.compile(regexp) if storage_format is not None: @@ -978,7 +976,7 @@ class _DateTimeMixin: kw["storage_format"] = self._storage_format if self._reg: kw["regexp"] = self._reg - return super(_DateTimeMixin, self).adapt(cls, **kw) + return super().adapt(cls, **kw) def literal_processor(self, dialect): bp = self.bind_processor(dialect) @@ -1037,7 +1035,7 @@ class DATETIME(_DateTimeMixin, sqltypes.DateTime): def __init__(self, *args, **kwargs): truncate_microseconds = kwargs.pop("truncate_microseconds", False) - super(DATETIME, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) if truncate_microseconds: assert "storage_format" not in kwargs, ( "You can specify only " @@ -1215,7 +1213,7 @@ class TIME(_DateTimeMixin, sqltypes.Time): def __init__(self, *args, **kwargs): truncate_microseconds = kwargs.pop("truncate_microseconds", False) - super(TIME, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) if truncate_microseconds: assert "storage_format" not in kwargs, ( "You can specify only " @@ -1337,7 +1335,7 @@ class SQLiteCompiler(compiler.SQLCompiler): def visit_cast(self, cast, **kwargs): if self.dialect.supports_cast: - return super(SQLiteCompiler, self).visit_cast(cast, **kwargs) + return super().visit_cast(cast, **kwargs) else: return self.process(cast.clause, **kwargs) @@ -1610,9 +1608,7 @@ class SQLiteDDLCompiler(compiler.DDLCompiler): ): return None - text = super(SQLiteDDLCompiler, self).visit_primary_key_constraint( - constraint - ) + text = super().visit_primary_key_constraint(constraint) on_conflict_clause = constraint.dialect_options["sqlite"][ "on_conflict" @@ -1628,9 +1624,7 @@ class SQLiteDDLCompiler(compiler.DDLCompiler): return text def visit_unique_constraint(self, constraint): - text = super(SQLiteDDLCompiler, self).visit_unique_constraint( - constraint - ) + text = super().visit_unique_constraint(constraint) on_conflict_clause = constraint.dialect_options["sqlite"][ "on_conflict" @@ -1648,9 +1642,7 @@ class SQLiteDDLCompiler(compiler.DDLCompiler): return text def visit_check_constraint(self, constraint): - text = super(SQLiteDDLCompiler, self).visit_check_constraint( - constraint - ) + text = super().visit_check_constraint(constraint) on_conflict_clause = constraint.dialect_options["sqlite"][ "on_conflict" @@ -1662,9 +1654,7 @@ class SQLiteDDLCompiler(compiler.DDLCompiler): return text def visit_column_check_constraint(self, constraint): - text = super(SQLiteDDLCompiler, self).visit_column_check_constraint( - constraint - ) + text = super().visit_column_check_constraint(constraint) if constraint.dialect_options["sqlite"]["on_conflict"] is not None: raise exc.CompileError( @@ -1682,9 +1672,7 @@ class SQLiteDDLCompiler(compiler.DDLCompiler): if local_table.schema != remote_table.schema: return None else: - return super(SQLiteDDLCompiler, self).visit_foreign_key_constraint( - constraint - ) + return super().visit_foreign_key_constraint(constraint) def define_constraint_remote_table(self, constraint, table, preparer): """Format the remote table clause of a CREATE CONSTRAINT clause.""" @@ -1741,7 +1729,7 @@ class SQLiteTypeCompiler(compiler.GenericTypeCompiler): not isinstance(type_, _DateTimeMixin) or type_.format_is_text_affinity ): - return super(SQLiteTypeCompiler, self).visit_DATETIME(type_) + return super().visit_DATETIME(type_) else: return "DATETIME_CHAR" @@ -1750,7 +1738,7 @@ class SQLiteTypeCompiler(compiler.GenericTypeCompiler): not isinstance(type_, _DateTimeMixin) or type_.format_is_text_affinity ): - return super(SQLiteTypeCompiler, self).visit_DATE(type_) + return super().visit_DATE(type_) else: return "DATE_CHAR" @@ -1759,7 +1747,7 @@ class SQLiteTypeCompiler(compiler.GenericTypeCompiler): not isinstance(type_, _DateTimeMixin) or type_.format_is_text_affinity ): - return super(SQLiteTypeCompiler, self).visit_TIME(type_) + return super().visit_TIME(type_) else: return "TIME_CHAR" @@ -1771,127 +1759,125 @@ class SQLiteTypeCompiler(compiler.GenericTypeCompiler): class SQLiteIdentifierPreparer(compiler.IdentifierPreparer): - reserved_words = set( - [ - "add", - "after", - "all", - "alter", - "analyze", - "and", - "as", - "asc", - "attach", - "autoincrement", - "before", - "begin", - "between", - "by", - "cascade", - "case", - "cast", - "check", - "collate", - "column", - "commit", - "conflict", - "constraint", - "create", - "cross", - "current_date", - "current_time", - "current_timestamp", - "database", - "default", - "deferrable", - "deferred", - "delete", - "desc", - "detach", - "distinct", - "drop", - "each", - "else", - "end", - "escape", - "except", - "exclusive", - "exists", - "explain", - "false", - "fail", - "for", - "foreign", - "from", - "full", - "glob", - "group", - "having", - "if", - "ignore", - "immediate", - "in", - "index", - "indexed", - "initially", - "inner", - "insert", - "instead", - "intersect", - "into", - "is", - "isnull", - "join", - "key", - "left", - "like", - "limit", - "match", - "natural", - "not", - "notnull", - "null", - "of", - "offset", - "on", - "or", - "order", - "outer", - "plan", - "pragma", - "primary", - "query", - "raise", - "references", - "reindex", - "rename", - "replace", - "restrict", - "right", - "rollback", - "row", - "select", - "set", - "table", - "temp", - "temporary", - "then", - "to", - "transaction", - "trigger", - "true", - "union", - "unique", - "update", - "using", - "vacuum", - "values", - "view", - "virtual", - "when", - "where", - ] - ) + reserved_words = { + "add", + "after", + "all", + "alter", + "analyze", + "and", + "as", + "asc", + "attach", + "autoincrement", + "before", + "begin", + "between", + "by", + "cascade", + "case", + "cast", + "check", + "collate", + "column", + "commit", + "conflict", + "constraint", + "create", + "cross", + "current_date", + "current_time", + "current_timestamp", + "database", + "default", + "deferrable", + "deferred", + "delete", + "desc", + "detach", + "distinct", + "drop", + "each", + "else", + "end", + "escape", + "except", + "exclusive", + "exists", + "explain", + "false", + "fail", + "for", + "foreign", + "from", + "full", + "glob", + "group", + "having", + "if", + "ignore", + "immediate", + "in", + "index", + "indexed", + "initially", + "inner", + "insert", + "instead", + "intersect", + "into", + "is", + "isnull", + "join", + "key", + "left", + "like", + "limit", + "match", + "natural", + "not", + "notnull", + "null", + "of", + "offset", + "on", + "or", + "order", + "outer", + "plan", + "pragma", + "primary", + "query", + "raise", + "references", + "reindex", + "rename", + "replace", + "restrict", + "right", + "rollback", + "row", + "select", + "set", + "table", + "temp", + "temporary", + "then", + "to", + "transaction", + "trigger", + "true", + "union", + "unique", + "update", + "using", + "vacuum", + "values", + "view", + "virtual", + "when", + "where", + } class SQLiteExecutionContext(default.DefaultExecutionContext): @@ -2454,17 +2440,14 @@ class SQLiteDialect(default.DefaultDialect): # the names as well. SQLite saves the DDL in whatever format # it was typed in as, so need to be liberal here. - keys_by_signature = dict( - ( - fk_sig( - fk["constrained_columns"], - fk["referred_table"], - fk["referred_columns"], - ), - fk, - ) + keys_by_signature = { + fk_sig( + fk["constrained_columns"], + fk["referred_table"], + fk["referred_columns"], + ): fk for fk in fks.values() - ) + } table_data = self._get_table_sql(connection, table_name, schema=schema) diff --git a/lib/sqlalchemy/dialects/sqlite/dml.py b/lib/sqlalchemy/dialects/sqlite/dml.py index 9e9e68330..0777c9261 100644 --- a/lib/sqlalchemy/dialects/sqlite/dml.py +++ b/lib/sqlalchemy/dialects/sqlite/dml.py @@ -202,7 +202,7 @@ class OnConflictDoUpdate(OnConflictClause): set_=None, where=None, ): - super(OnConflictDoUpdate, self).__init__( + super().__init__( index_elements=index_elements, index_where=index_where, ) diff --git a/lib/sqlalchemy/dialects/sqlite/pysqlcipher.py b/lib/sqlalchemy/dialects/sqlite/pysqlcipher.py index 53e4b0d1b..5c07f487c 100644 --- a/lib/sqlalchemy/dialects/sqlite/pysqlcipher.py +++ b/lib/sqlalchemy/dialects/sqlite/pysqlcipher.py @@ -124,9 +124,7 @@ class SQLiteDialect_pysqlcipher(SQLiteDialect_pysqlite): return pool.SingletonThreadPool def on_connect_url(self, url): - super_on_connect = super( - SQLiteDialect_pysqlcipher, self - ).on_connect_url(url) + super_on_connect = super().on_connect_url(url) # pull the info we need from the URL early. Even though URL # is immutable, we don't want any in-place changes to the URL @@ -151,9 +149,7 @@ class SQLiteDialect_pysqlcipher(SQLiteDialect_pysqlite): def create_connect_args(self, url): plain_url = url._replace(password=None) plain_url = plain_url.difference_update_query(self.pragmas) - return super(SQLiteDialect_pysqlcipher, self).create_connect_args( - plain_url - ) + return super().create_connect_args(plain_url) dialect = SQLiteDialect_pysqlcipher diff --git a/lib/sqlalchemy/dialects/sqlite/pysqlite.py b/lib/sqlalchemy/dialects/sqlite/pysqlite.py index 19949441f..4475ccae7 100644 --- a/lib/sqlalchemy/dialects/sqlite/pysqlite.py +++ b/lib/sqlalchemy/dialects/sqlite/pysqlite.py @@ -536,9 +536,7 @@ class SQLiteDialect_pysqlite(SQLiteDialect): dbapi_connection.isolation_level = None else: dbapi_connection.isolation_level = "" - return super(SQLiteDialect_pysqlite, self).set_isolation_level( - dbapi_connection, level - ) + return super().set_isolation_level(dbapi_connection, level) def on_connect(self): def regexp(a, b): diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index b2f6b29b7..b686de0d6 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -2854,7 +2854,7 @@ class TwoPhaseTransaction(RootTransaction): def __init__(self, connection: Connection, xid: Any): self._is_prepared = False self.xid = xid - super(TwoPhaseTransaction, self).__init__(connection) + super().__init__(connection) def prepare(self) -> None: """Prepare this :class:`.TwoPhaseTransaction`. diff --git a/lib/sqlalchemy/engine/create.py b/lib/sqlalchemy/engine/create.py index 1ad8c90e7..c8736392f 100644 --- a/lib/sqlalchemy/engine/create.py +++ b/lib/sqlalchemy/engine/create.py @@ -115,7 +115,7 @@ def create_engine(url: Union[str, URL], **kwargs: Any) -> Engine: "is deprecated and will be removed in a future release. ", ), ) -def create_engine(url: Union[str, "_url.URL"], **kwargs: Any) -> Engine: +def create_engine(url: Union[str, _url.URL], **kwargs: Any) -> Engine: """Create a new :class:`_engine.Engine` instance. The standard calling form is to send the :ref:`URL <database_urls>` as the @@ -806,11 +806,11 @@ def engine_from_config( """ - options = dict( - (key[len(prefix) :], configuration[key]) + options = { + key[len(prefix) :]: configuration[key] for key in configuration if key.startswith(prefix) - ) + } options["_coerce_config"] = True options.update(kwargs) url = options.pop("url") diff --git a/lib/sqlalchemy/engine/cursor.py b/lib/sqlalchemy/engine/cursor.py index f22e89fbe..33ee7866c 100644 --- a/lib/sqlalchemy/engine/cursor.py +++ b/lib/sqlalchemy/engine/cursor.py @@ -1230,15 +1230,11 @@ class BufferedRowCursorFetchStrategy(CursorFetchStrategy): def soft_close(self, result, dbapi_cursor): self._rowbuffer.clear() - super(BufferedRowCursorFetchStrategy, self).soft_close( - result, dbapi_cursor - ) + super().soft_close(result, dbapi_cursor) def hard_close(self, result, dbapi_cursor): self._rowbuffer.clear() - super(BufferedRowCursorFetchStrategy, self).hard_close( - result, dbapi_cursor - ) + super().hard_close(result, dbapi_cursor) def fetchone(self, result, dbapi_cursor, hard_close=False): if not self._rowbuffer: @@ -1307,15 +1303,11 @@ class FullyBufferedCursorFetchStrategy(CursorFetchStrategy): def soft_close(self, result, dbapi_cursor): self._rowbuffer.clear() - super(FullyBufferedCursorFetchStrategy, self).soft_close( - result, dbapi_cursor - ) + super().soft_close(result, dbapi_cursor) def hard_close(self, result, dbapi_cursor): self._rowbuffer.clear() - super(FullyBufferedCursorFetchStrategy, self).hard_close( - result, dbapi_cursor - ) + super().hard_close(result, dbapi_cursor) def fetchone(self, result, dbapi_cursor, hard_close=False): if self._rowbuffer: diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index e5d613dd5..3cc9cab8b 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -200,9 +200,7 @@ class DefaultDialect(Dialect): supports_sane_rowcount = True supports_sane_multi_rowcount = True - colspecs: MutableMapping[ - Type["TypeEngine[Any]"], Type["TypeEngine[Any]"] - ] = {} + colspecs: MutableMapping[Type[TypeEngine[Any]], Type[TypeEngine[Any]]] = {} default_paramstyle = "named" supports_default_values = False @@ -1486,21 +1484,17 @@ class DefaultExecutionContext(ExecutionContext): use_server_side = self.execution_options.get( "stream_results", True ) and ( - ( - self.compiled - and isinstance( - self.compiled.statement, expression.Selectable - ) - or ( - ( - not self.compiled - or isinstance( - self.compiled.statement, expression.TextClause - ) + self.compiled + and isinstance(self.compiled.statement, expression.Selectable) + or ( + ( + not self.compiled + or isinstance( + self.compiled.statement, expression.TextClause ) - and self.unicode_statement - and SERVER_SIDE_CURSOR_RE.match(self.unicode_statement) ) + and self.unicode_statement + and SERVER_SIDE_CURSOR_RE.match(self.unicode_statement) ) ) else: @@ -1938,15 +1932,12 @@ class DefaultExecutionContext(ExecutionContext): ] ) else: - parameters = dict( - ( - key, - processors[key](compiled_params[key]) # type: ignore - if key in processors - else compiled_params[key], - ) + parameters = { + key: processors[key](compiled_params[key]) # type: ignore + if key in processors + else compiled_params[key] for key in compiled_params - ) + } return self._execute_scalar( str(compiled), type_, parameters=parameters ) diff --git a/lib/sqlalchemy/engine/interfaces.py b/lib/sqlalchemy/engine/interfaces.py index e10fab831..2f5efce25 100644 --- a/lib/sqlalchemy/engine/interfaces.py +++ b/lib/sqlalchemy/engine/interfaces.py @@ -757,7 +757,7 @@ class Dialect(EventTarget): # create_engine() -> isolation_level currently goes here _on_connect_isolation_level: Optional[IsolationLevel] - execution_ctx_cls: Type["ExecutionContext"] + execution_ctx_cls: Type[ExecutionContext] """a :class:`.ExecutionContext` class used to handle statement execution""" execute_sequence_format: Union[ @@ -963,7 +963,7 @@ class Dialect(EventTarget): """target database, when given a CTE with an INSERT statement, needs the CTE to be below the INSERT""" - colspecs: MutableMapping[Type["TypeEngine[Any]"], Type["TypeEngine[Any]"]] + colspecs: MutableMapping[Type[TypeEngine[Any]], Type[TypeEngine[Any]]] """A dictionary of TypeEngine classes from sqlalchemy.types mapped to subclasses that are specific to the dialect class. This dictionary is class-level only and is not accessed from the @@ -1160,12 +1160,12 @@ class Dialect(EventTarget): _bind_typing_render_casts: bool - _type_memos: MutableMapping[TypeEngine[Any], "_TypeMemoDict"] + _type_memos: MutableMapping[TypeEngine[Any], _TypeMemoDict] def _builtin_onconnect(self) -> Optional[_ListenerFnType]: raise NotImplementedError() - def create_connect_args(self, url: "URL") -> ConnectArgsType: + def create_connect_args(self, url: URL) -> ConnectArgsType: """Build DB-API compatible connection arguments. Given a :class:`.URL` object, returns a tuple @@ -1217,7 +1217,7 @@ class Dialect(EventTarget): raise NotImplementedError() @classmethod - def type_descriptor(cls, typeobj: "TypeEngine[_T]") -> "TypeEngine[_T]": + def type_descriptor(cls, typeobj: TypeEngine[_T]) -> TypeEngine[_T]: """Transform a generic type to a dialect-specific type. Dialect classes will usually use the @@ -2155,7 +2155,7 @@ class Dialect(EventTarget): self, cursor: DBAPICursor, statement: str, - context: Optional["ExecutionContext"] = None, + context: Optional[ExecutionContext] = None, ) -> None: """Provide an implementation of ``cursor.execute(statement)``. @@ -2210,7 +2210,7 @@ class Dialect(EventTarget): """ raise NotImplementedError() - def on_connect_url(self, url: "URL") -> Optional[Callable[[Any], Any]]: + def on_connect_url(self, url: URL) -> Optional[Callable[[Any], Any]]: """return a callable which sets up a newly created DBAPI connection. This method is a new hook that supersedes the @@ -2556,7 +2556,7 @@ class Dialect(EventTarget): """ @classmethod - def engine_created(cls, engine: "Engine") -> None: + def engine_created(cls, engine: Engine) -> None: """A convenience hook called before returning the final :class:`_engine.Engine`. diff --git a/lib/sqlalchemy/engine/reflection.py b/lib/sqlalchemy/engine/reflection.py index f744d53ad..d1669cc3c 100644 --- a/lib/sqlalchemy/engine/reflection.py +++ b/lib/sqlalchemy/engine/reflection.py @@ -568,9 +568,9 @@ class Inspector(inspection.Inspectable["Inspector"]): schema_fkeys = self.get_multi_foreign_keys(schname, **kw) tnames.extend(schema_fkeys) for (_, tname), fkeys in schema_fkeys.items(): - fknames_for_table[(schname, tname)] = set( - [fk["name"] for fk in fkeys] - ) + fknames_for_table[(schname, tname)] = { + fk["name"] for fk in fkeys + } for fkey in fkeys: if ( tname != fkey["referred_table"] @@ -1517,11 +1517,11 @@ class Inspector(inspection.Inspectable["Inspector"]): # intended for reflection, e.g. oracle_resolve_synonyms. # these are unconditionally passed to related Table # objects - reflection_options = dict( - (k, table.dialect_kwargs.get(k)) + reflection_options = { + k: table.dialect_kwargs.get(k) for k in dialect.reflection_options if k in table.dialect_kwargs - ) + } table_key = (schema, table_name) if _reflect_info is None or table_key not in _reflect_info.columns: @@ -1644,8 +1644,8 @@ class Inspector(inspection.Inspectable["Inspector"]): coltype = col_d["type"] - col_kw = dict( - (k, col_d[k]) # type: ignore[literal-required] + col_kw = { + k: col_d[k] # type: ignore[literal-required] for k in [ "nullable", "autoincrement", @@ -1655,7 +1655,7 @@ class Inspector(inspection.Inspectable["Inspector"]): "comment", ] if k in col_d - ) + } if "dialect_options" in col_d: col_kw.update(col_d["dialect_options"]) diff --git a/lib/sqlalchemy/engine/result.py b/lib/sqlalchemy/engine/result.py index bcd2f0ea9..392cefa02 100644 --- a/lib/sqlalchemy/engine/result.py +++ b/lib/sqlalchemy/engine/result.py @@ -2342,7 +2342,7 @@ class ChunkedIteratorResult(IteratorResult[_TP]): return self def _soft_close(self, hard: bool = False, **kw: Any) -> None: - super(ChunkedIteratorResult, self)._soft_close(hard=hard, **kw) + super()._soft_close(hard=hard, **kw) self.chunks = lambda size: [] # type: ignore def _fetchmany_impl( @@ -2370,7 +2370,7 @@ class MergedResult(IteratorResult[_TP]): self, cursor_metadata: ResultMetaData, results: Sequence[Result[_TP]] ): self._results = results - super(MergedResult, self).__init__( + super().__init__( cursor_metadata, itertools.chain.from_iterable( r._raw_row_iterator() for r in results diff --git a/lib/sqlalchemy/event/legacy.py b/lib/sqlalchemy/event/legacy.py index 3d4341410..a06b05940 100644 --- a/lib/sqlalchemy/event/legacy.py +++ b/lib/sqlalchemy/event/legacy.py @@ -58,7 +58,7 @@ def _legacy_signature( def _wrap_fn_for_legacy( - dispatch_collection: "_ClsLevelDispatch[_ET]", + dispatch_collection: _ClsLevelDispatch[_ET], fn: _ListenerFnType, argspec: FullArgSpec, ) -> _ListenerFnType: @@ -120,7 +120,7 @@ def _indent(text: str, indent: str) -> str: def _standard_listen_example( - dispatch_collection: "_ClsLevelDispatch[_ET]", + dispatch_collection: _ClsLevelDispatch[_ET], sample_target: Any, fn: _ListenerFnType, ) -> str: @@ -161,7 +161,7 @@ def _standard_listen_example( def _legacy_listen_examples( - dispatch_collection: "_ClsLevelDispatch[_ET]", + dispatch_collection: _ClsLevelDispatch[_ET], sample_target: str, fn: _ListenerFnType, ) -> str: @@ -189,8 +189,8 @@ def _legacy_listen_examples( def _version_signature_changes( - parent_dispatch_cls: Type["_HasEventsDispatch[_ET]"], - dispatch_collection: "_ClsLevelDispatch[_ET]", + parent_dispatch_cls: Type[_HasEventsDispatch[_ET]], + dispatch_collection: _ClsLevelDispatch[_ET], ) -> str: since, args, conv = dispatch_collection.legacy_signatures[0] return ( @@ -219,8 +219,8 @@ def _version_signature_changes( def _augment_fn_docs( - dispatch_collection: "_ClsLevelDispatch[_ET]", - parent_dispatch_cls: Type["_HasEventsDispatch[_ET]"], + dispatch_collection: _ClsLevelDispatch[_ET], + parent_dispatch_cls: Type[_HasEventsDispatch[_ET]], fn: _ListenerFnType, ) -> str: header = ( diff --git a/lib/sqlalchemy/exc.py b/lib/sqlalchemy/exc.py index 88edba328..fa46a46c4 100644 --- a/lib/sqlalchemy/exc.py +++ b/lib/sqlalchemy/exc.py @@ -49,7 +49,7 @@ class HasDescriptionCode: code = kw.pop("code", None) if code is not None: self.code = code - super(HasDescriptionCode, self).__init__(*arg, **kw) + super().__init__(*arg, **kw) def _code_str(self) -> str: if not self.code: @@ -65,7 +65,7 @@ class HasDescriptionCode: ) def __str__(self) -> str: - message = super(HasDescriptionCode, self).__str__() + message = super().__str__() if self.code: message = "%s %s" % (message, self._code_str()) return message @@ -134,9 +134,7 @@ class ObjectNotExecutableError(ArgumentError): """ def __init__(self, target: Any): - super(ObjectNotExecutableError, self).__init__( - "Not an executable object: %r" % target - ) + super().__init__("Not an executable object: %r" % target) self.target = target def __reduce__(self) -> Union[str, Tuple[Any, ...]]: @@ -223,7 +221,7 @@ class UnsupportedCompilationError(CompileError): element_type: Type[ClauseElement], message: Optional[str] = None, ): - super(UnsupportedCompilationError, self).__init__( + super().__init__( "Compiler %r can't render element of type %s%s" % (compiler, element_type, ": %s" % message if message else "") ) @@ -557,7 +555,7 @@ class DBAPIError(StatementError): dbapi_base_err: Type[Exception], hide_parameters: bool = False, connection_invalidated: bool = False, - dialect: Optional["Dialect"] = None, + dialect: Optional[Dialect] = None, ismulti: Optional[bool] = None, ) -> StatementError: ... @@ -572,7 +570,7 @@ class DBAPIError(StatementError): dbapi_base_err: Type[Exception], hide_parameters: bool = False, connection_invalidated: bool = False, - dialect: Optional["Dialect"] = None, + dialect: Optional[Dialect] = None, ismulti: Optional[bool] = None, ) -> DontWrapMixin: ... @@ -587,7 +585,7 @@ class DBAPIError(StatementError): dbapi_base_err: Type[Exception], hide_parameters: bool = False, connection_invalidated: bool = False, - dialect: Optional["Dialect"] = None, + dialect: Optional[Dialect] = None, ismulti: Optional[bool] = None, ) -> BaseException: ... @@ -601,7 +599,7 @@ class DBAPIError(StatementError): dbapi_base_err: Type[Exception], hide_parameters: bool = False, connection_invalidated: bool = False, - dialect: Optional["Dialect"] = None, + dialect: Optional[Dialect] = None, ismulti: Optional[bool] = None, ) -> Union[BaseException, DontWrapMixin]: # Don't ever wrap these, just return them directly as if @@ -792,7 +790,7 @@ class Base20DeprecationWarning(SADeprecationWarning): def __str__(self) -> str: return ( - super(Base20DeprecationWarning, self).__str__() + super().__str__() + " (Background on SQLAlchemy 2.0 at: https://sqlalche.me/e/b8d9)" ) diff --git a/lib/sqlalchemy/ext/associationproxy.py b/lib/sqlalchemy/ext/associationproxy.py index bfec09137..f4adf3d29 100644 --- a/lib/sqlalchemy/ext/associationproxy.py +++ b/lib/sqlalchemy/ext/associationproxy.py @@ -1071,7 +1071,7 @@ class AmbiguousAssociationProxyInstance(AssociationProxyInstance[_T]): if obj is None: return self else: - return super(AmbiguousAssociationProxyInstance, self).get(obj) + return super().get(obj) def __eq__(self, obj: object) -> NoReturn: self._ambiguous() diff --git a/lib/sqlalchemy/ext/automap.py b/lib/sqlalchemy/ext/automap.py index 6d441c9e3..6eb30ba4c 100644 --- a/lib/sqlalchemy/ext/automap.py +++ b/lib/sqlalchemy/ext/automap.py @@ -884,12 +884,12 @@ class AutomapBase: cls.metadata.reflect(autoload_with, **opts) with _CONFIGURE_MUTEX: - table_to_map_config = dict( - (m.local_table, m) + table_to_map_config = { + m.local_table: m for m in _DeferredMapperConfig.classes_for_base( cls, sort=False ) - ) + } many_to_many = [] diff --git a/lib/sqlalchemy/ext/baked.py b/lib/sqlalchemy/ext/baked.py index 7093de732..48e57e2bc 100644 --- a/lib/sqlalchemy/ext/baked.py +++ b/lib/sqlalchemy/ext/baked.py @@ -525,15 +525,13 @@ class Result: # None present in ident - turn those comparisons # into "IS NULL" if None in primary_key_identity: - nones = set( - [ - _get_params[col].key - for col, value in zip( - mapper.primary_key, primary_key_identity - ) - if value is None - ] - ) + nones = { + _get_params[col].key + for col, value in zip( + mapper.primary_key, primary_key_identity + ) + if value is None + } _lcl_get_clause = sql_util.adapt_criterion_to_null( _lcl_get_clause, nones ) @@ -562,14 +560,12 @@ class Result: setup, tuple(elem is None for elem in primary_key_identity) ) - params = dict( - [ - (_get_params[primary_key].key, id_val) - for id_val, primary_key in zip( - primary_key_identity, mapper.primary_key - ) - ] - ) + params = { + _get_params[primary_key].key: id_val + for id_val, primary_key in zip( + primary_key_identity, mapper.primary_key + ) + } result = list(bq.for_session(self.session).params(**params)) l = len(result) diff --git a/lib/sqlalchemy/ext/horizontal_shard.py b/lib/sqlalchemy/ext/horizontal_shard.py index 7afe2343d..8f6e2ffcd 100644 --- a/lib/sqlalchemy/ext/horizontal_shard.py +++ b/lib/sqlalchemy/ext/horizontal_shard.py @@ -28,7 +28,7 @@ __all__ = ["ShardedSession", "ShardedQuery"] class ShardedQuery(Query): def __init__(self, *args, **kwargs): - super(ShardedQuery, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) self.id_chooser = self.session.id_chooser self.query_chooser = self.session.query_chooser self.execute_chooser = self.session.execute_chooser @@ -88,7 +88,7 @@ class ShardedSession(Session): """ query_chooser = kwargs.pop("query_chooser", None) - super(ShardedSession, self).__init__(query_cls=query_cls, **kwargs) + super().__init__(query_cls=query_cls, **kwargs) event.listen( self, "do_orm_execute", execute_and_instances, retval=True @@ -138,7 +138,7 @@ class ShardedSession(Session): """ if identity_token is not None: - return super(ShardedSession, self)._identity_lookup( + return super()._identity_lookup( mapper, primary_key_identity, identity_token=identity_token, @@ -149,7 +149,7 @@ class ShardedSession(Session): if lazy_loaded_from: q = q._set_lazyload_from(lazy_loaded_from) for shard_id in self.id_chooser(q, primary_key_identity): - obj = super(ShardedSession, self)._identity_lookup( + obj = super()._identity_lookup( mapper, primary_key_identity, identity_token=shard_id, diff --git a/lib/sqlalchemy/ext/indexable.py b/lib/sqlalchemy/ext/indexable.py index 5c5d26736..ce6312365 100644 --- a/lib/sqlalchemy/ext/indexable.py +++ b/lib/sqlalchemy/ext/indexable.py @@ -278,13 +278,9 @@ class index_property(hybrid_property): # noqa """ if mutable: - super(index_property, self).__init__( - self.fget, self.fset, self.fdel, self.expr - ) + super().__init__(self.fget, self.fset, self.fdel, self.expr) else: - super(index_property, self).__init__( - self.fget, None, None, self.expr - ) + super().__init__(self.fget, None, None, self.expr) self.attr_name = attr_name self.index = index self.default = default diff --git a/lib/sqlalchemy/ext/instrumentation.py b/lib/sqlalchemy/ext/instrumentation.py index 427e151da..f36087ad9 100644 --- a/lib/sqlalchemy/ext/instrumentation.py +++ b/lib/sqlalchemy/ext/instrumentation.py @@ -165,7 +165,7 @@ class ExtendedInstrumentationRegistry(InstrumentationFactory): return factories def unregister(self, class_): - super(ExtendedInstrumentationRegistry, self).unregister(class_) + super().unregister(class_) if class_ in self._manager_finders: del self._manager_finders[class_] del self._state_finders[class_] @@ -321,7 +321,7 @@ class _ClassInstrumentationAdapter(ClassManager): self._adapted.instrument_attribute(self.class_, key, inst) def post_configure_attribute(self, key): - super(_ClassInstrumentationAdapter, self).post_configure_attribute(key) + super().post_configure_attribute(key) self._adapted.post_configure_attribute(self.class_, key, self[key]) def install_descriptor(self, key, inst): diff --git a/lib/sqlalchemy/ext/mypy/apply.py b/lib/sqlalchemy/ext/mypy/apply.py index bfc3459d0..f392a85a7 100644 --- a/lib/sqlalchemy/ext/mypy/apply.py +++ b/lib/sqlalchemy/ext/mypy/apply.py @@ -63,7 +63,7 @@ def apply_mypy_mapped_attr( ): break else: - util.fail(api, "Can't find mapped attribute {}".format(name), cls) + util.fail(api, f"Can't find mapped attribute {name}", cls) return None if stmt.type is None: diff --git a/lib/sqlalchemy/ext/mypy/util.py b/lib/sqlalchemy/ext/mypy/util.py index 44a1768a8..a32bc9b52 100644 --- a/lib/sqlalchemy/ext/mypy/util.py +++ b/lib/sqlalchemy/ext/mypy/util.py @@ -90,7 +90,7 @@ class SQLAlchemyAttribute: info: TypeInfo, data: JsonDict, api: SemanticAnalyzerPluginInterface, - ) -> "SQLAlchemyAttribute": + ) -> SQLAlchemyAttribute: data = data.copy() typ = deserialize_and_fixup_type(data.pop("type"), api) return cls(typ=typ, info=info, **data) @@ -238,8 +238,7 @@ def flatten_typechecking(stmts: Iterable[Statement]) -> Iterator[Statement]: and isinstance(stmt.expr[0], NameExpr) and stmt.expr[0].fullname == "typing.TYPE_CHECKING" ): - for substmt in stmt.body[0].body: - yield substmt + yield from stmt.body[0].body else: yield stmt diff --git a/lib/sqlalchemy/ext/orderinglist.py b/lib/sqlalchemy/ext/orderinglist.py index f08ffc68d..b0615d95d 100644 --- a/lib/sqlalchemy/ext/orderinglist.py +++ b/lib/sqlalchemy/ext/orderinglist.py @@ -143,7 +143,7 @@ def ordering_list( count_from: Optional[int] = None, ordering_func: Optional[OrderingFunc] = None, reorder_on_append: bool = False, -) -> Callable[[], "OrderingList"]: +) -> Callable[[], OrderingList]: """Prepares an :class:`OrderingList` factory for use in mapper definitions. Returns an object suitable for use as an argument to a Mapper @@ -335,29 +335,29 @@ class OrderingList(List[_T]): self._set_order_value(entity, should_be) def append(self, entity): - super(OrderingList, self).append(entity) + super().append(entity) self._order_entity(len(self) - 1, entity, self.reorder_on_append) def _raw_append(self, entity): """Append without any ordering behavior.""" - super(OrderingList, self).append(entity) + super().append(entity) _raw_append = collection.adds(1)(_raw_append) def insert(self, index, entity): - super(OrderingList, self).insert(index, entity) + super().insert(index, entity) self._reorder() def remove(self, entity): - super(OrderingList, self).remove(entity) + super().remove(entity) adapter = collection_adapter(self) if adapter and adapter._referenced_by_owner: self._reorder() def pop(self, index=-1): - entity = super(OrderingList, self).pop(index) + entity = super().pop(index) self._reorder() return entity @@ -375,18 +375,18 @@ class OrderingList(List[_T]): self.__setitem__(i, entity[i]) else: self._order_entity(index, entity, True) - super(OrderingList, self).__setitem__(index, entity) + super().__setitem__(index, entity) def __delitem__(self, index): - super(OrderingList, self).__delitem__(index) + super().__delitem__(index) self._reorder() def __setslice__(self, start, end, values): - super(OrderingList, self).__setslice__(start, end, values) + super().__setslice__(start, end, values) self._reorder() def __delslice__(self, start, end): - super(OrderingList, self).__delslice__(start, end) + super().__delslice__(start, end) self._reorder() def __reduce__(self): diff --git a/lib/sqlalchemy/log.py b/lib/sqlalchemy/log.py index f7050b93f..dd295c3ed 100644 --- a/lib/sqlalchemy/log.py +++ b/lib/sqlalchemy/log.py @@ -63,10 +63,10 @@ def _add_default_handler(logger: logging.Logger) -> None: logger.addHandler(handler) -_logged_classes: Set[Type["Identified"]] = set() +_logged_classes: Set[Type[Identified]] = set() -def _qual_logger_name_for_cls(cls: Type["Identified"]) -> str: +def _qual_logger_name_for_cls(cls: Type[Identified]) -> str: return ( getattr(cls, "_sqla_logger_namespace", None) or cls.__module__ + "." + cls.__name__ diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index 854bad986..2c77111c1 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -1200,7 +1200,7 @@ class ScalarAttributeImpl(AttributeImpl): __slots__ = "_replace_token", "_append_token", "_remove_token" def __init__(self, *arg, **kw): - super(ScalarAttributeImpl, self).__init__(*arg, **kw) + super().__init__(*arg, **kw) self._replace_token = self._append_token = AttributeEventToken( self, OP_REPLACE ) @@ -1628,7 +1628,7 @@ class CollectionAttributeImpl(HasCollectionAdapter, AttributeImpl): compare_function=None, **kwargs, ): - super(CollectionAttributeImpl, self).__init__( + super().__init__( class_, key, callable_, diff --git a/lib/sqlalchemy/orm/bulk_persistence.py b/lib/sqlalchemy/orm/bulk_persistence.py index cfe488003..181dbd4a2 100644 --- a/lib/sqlalchemy/orm/bulk_persistence.py +++ b/lib/sqlalchemy/orm/bulk_persistence.py @@ -242,11 +242,11 @@ def _bulk_update( search_keys = {mapper._version_id_prop.key}.union(search_keys) def _changed_dict(mapper, state): - return dict( - (k, v) + return { + k: v for k, v in state.dict.items() if k in state.committed_state or k in search_keys - ) + } if isstates: if update_changed_only: @@ -1701,7 +1701,7 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState): value_evaluators[key] = _evaluator evaluated_keys = list(value_evaluators.keys()) - attrib = set(k for k, v in resolved_keys_as_propnames) + attrib = {k for k, v in resolved_keys_as_propnames} states = set() for obj, state, dict_ in matched_objects: diff --git a/lib/sqlalchemy/orm/clsregistry.py b/lib/sqlalchemy/orm/clsregistry.py index 99a51c998..b957dc5d4 100644 --- a/lib/sqlalchemy/orm/clsregistry.py +++ b/lib/sqlalchemy/orm/clsregistry.py @@ -187,9 +187,9 @@ class _MultipleClassMarker(ClsRegistryToken): on_remove: Optional[Callable[[], None]] = None, ): self.on_remove = on_remove - self.contents = set( - [weakref.ref(item, self._remove_item) for item in classes] - ) + self.contents = { + weakref.ref(item, self._remove_item) for item in classes + } _registries.add(self) def remove_item(self, cls: Type[Any]) -> None: @@ -224,13 +224,11 @@ class _MultipleClassMarker(ClsRegistryToken): # protect against class registration race condition against # asynchronous garbage collection calling _remove_item, # [ticket:3208] - modules = set( - [ - cls.__module__ - for cls in [ref() for ref in self.contents] - if cls is not None - ] - ) + modules = { + cls.__module__ + for cls in [ref() for ref in self.contents] + if cls is not None + } if item.__module__ in modules: util.warn( "This declarative base already contains a class with the " diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py index 60ccecdb7..621b3e5d7 100644 --- a/lib/sqlalchemy/orm/context.py +++ b/lib/sqlalchemy/orm/context.py @@ -838,12 +838,10 @@ class FromStatement(GroupedElement, Generative, TypedReturnsRows[_TP]): return self def get_children(self, **kw): - for elem in itertools.chain.from_iterable( + yield from itertools.chain.from_iterable( element._from_objects for element in self._raw_columns - ): - yield elem - for elem in super(FromStatement, self).get_children(**kw): - yield elem + ) + yield from super().get_children(**kw) @property def _all_selected_columns(self): @@ -1245,14 +1243,11 @@ class ORMSelectCompileState(ORMCompileState, SelectState): ): ens = element._annotations["entity_namespace"] if not ens.is_mapper and not ens.is_aliased_class: - for elem in _select_iterables([element]): - yield elem + yield from _select_iterables([element]) else: - for elem in _select_iterables(ens._all_column_expressions): - yield elem + yield from _select_iterables(ens._all_column_expressions) else: - for elem in _select_iterables([element]): - yield elem + yield from _select_iterables([element]) @classmethod def get_columns_clause_froms(cls, statement): diff --git a/lib/sqlalchemy/orm/decl_base.py b/lib/sqlalchemy/orm/decl_base.py index c233298b9..268a1d57a 100644 --- a/lib/sqlalchemy/orm/decl_base.py +++ b/lib/sqlalchemy/orm/decl_base.py @@ -342,9 +342,7 @@ class _ImperativeMapperConfig(_MapperConfig): table: Optional[FromClause], mapper_kw: _MapperKwArgs, ): - super(_ImperativeMapperConfig, self).__init__( - registry, cls_, mapper_kw - ) + super().__init__(registry, cls_, mapper_kw) self.local_table = self.set_cls_attribute("__table__", table) @@ -480,7 +478,7 @@ class _ClassScanMapperConfig(_MapperConfig): self.clsdict_view = ( util.immutabledict(dict_) if dict_ else util.EMPTY_DICT ) - super(_ClassScanMapperConfig, self).__init__(registry, cls_, mapper_kw) + super().__init__(registry, cls_, mapper_kw) self.registry = registry self.persist_selectable = None @@ -1636,13 +1634,11 @@ class _ClassScanMapperConfig(_MapperConfig): inherited_table = inherited_mapper.local_table if "exclude_properties" not in mapper_args: - mapper_args["exclude_properties"] = exclude_properties = set( - [ - c.key - for c in inherited_table.c - if c not in inherited_mapper._columntoproperty - ] - ).union(inherited_mapper.exclude_properties or ()) + mapper_args["exclude_properties"] = exclude_properties = { + c.key + for c in inherited_table.c + if c not in inherited_mapper._columntoproperty + }.union(inherited_mapper.exclude_properties or ()) exclude_properties.difference_update( [c.key for c in self.declared_columns] ) @@ -1758,7 +1754,7 @@ class _DeferredMapperConfig(_ClassScanMapperConfig): if not sort: return classes_for_base - all_m_by_cls = dict((m.cls, m) for m in classes_for_base) + all_m_by_cls = {m.cls: m for m in classes_for_base} tuples: List[Tuple[_DeferredMapperConfig, _DeferredMapperConfig]] = [] for m_cls in all_m_by_cls: @@ -1771,7 +1767,7 @@ class _DeferredMapperConfig(_ClassScanMapperConfig): def map(self, mapper_kw: _MapperKwArgs = util.EMPTY_DICT) -> Mapper[Any]: self._configs.pop(self._cls, None) - return super(_DeferredMapperConfig, self).map(mapper_kw) + return super().map(mapper_kw) def _add_attribute( diff --git a/lib/sqlalchemy/orm/events.py b/lib/sqlalchemy/orm/events.py index 73e2ee934..32de155a1 100644 --- a/lib/sqlalchemy/orm/events.py +++ b/lib/sqlalchemy/orm/events.py @@ -111,7 +111,7 @@ class InstrumentationEvents(event.Events): @classmethod def _clear(cls): - super(InstrumentationEvents, cls)._clear() + super()._clear() instrumentation._instrumentation_factory.dispatch._clear() def class_instrument(self, cls): @@ -266,7 +266,7 @@ class InstanceEvents(event.Events): @classmethod def _clear(cls): - super(InstanceEvents, cls)._clear() + super()._clear() _InstanceEventsHold._clear() def first_init(self, manager, cls): @@ -798,7 +798,7 @@ class MapperEvents(event.Events): @classmethod def _clear(cls): - super(MapperEvents, cls)._clear() + super()._clear() _MapperEventsHold._clear() def instrument_class(self, mapper, class_): diff --git a/lib/sqlalchemy/orm/instrumentation.py b/lib/sqlalchemy/orm/instrumentation.py index 33de2aee9..dfe09fcbd 100644 --- a/lib/sqlalchemy/orm/instrumentation.py +++ b/lib/sqlalchemy/orm/instrumentation.py @@ -399,8 +399,7 @@ class ClassManager( if mgr is not None and mgr is not self: yield mgr if recursive: - for m in mgr.subclass_managers(True): - yield m + yield from mgr.subclass_managers(True) def post_configure_attribute(self, key): _instrumentation_factory.dispatch.attribute_instrument( diff --git a/lib/sqlalchemy/orm/loading.py b/lib/sqlalchemy/orm/loading.py index 64f2542fd..edfa61287 100644 --- a/lib/sqlalchemy/orm/loading.py +++ b/lib/sqlalchemy/orm/loading.py @@ -535,15 +535,11 @@ def load_on_pk_identity( # None present in ident - turn those comparisons # into "IS NULL" if None in primary_key_identity: - nones = set( - [ - _get_params[col].key - for col, value in zip( - mapper.primary_key, primary_key_identity - ) - if value is None - ] - ) + nones = { + _get_params[col].key + for col, value in zip(mapper.primary_key, primary_key_identity) + if value is None + } _get_clause = sql_util.adapt_criterion_to_null(_get_clause, nones) @@ -558,14 +554,12 @@ def load_on_pk_identity( sql_util._deep_annotate(_get_clause, {"_orm_adapt": True}), ) - params = dict( - [ - (_get_params[primary_key].key, id_val) - for id_val, primary_key in zip( - primary_key_identity, mapper.primary_key - ) - ] - ) + params = { + _get_params[primary_key].key: id_val + for id_val, primary_key in zip( + primary_key_identity, mapper.primary_key + ) + } else: params = None diff --git a/lib/sqlalchemy/orm/mapped_collection.py b/lib/sqlalchemy/orm/mapped_collection.py index 1aa864f7e..8bacb87df 100644 --- a/lib/sqlalchemy/orm/mapped_collection.py +++ b/lib/sqlalchemy/orm/mapped_collection.py @@ -175,7 +175,7 @@ class _AttrGetter: def attribute_keyed_dict( attr_name: str, *, ignore_unpopulated_attribute: bool = False -) -> Type["KeyFuncDict"]: +) -> Type[KeyFuncDict]: """A dictionary-based collection type with attribute-based keying. .. versionchanged:: 2.0 Renamed :data:`.attribute_mapped_collection` to @@ -226,7 +226,7 @@ def keyfunc_mapping( keyfunc: Callable[[Any], _KT], *, ignore_unpopulated_attribute: bool = False, -) -> Type["KeyFuncDict[_KT, Any]"]: +) -> Type[KeyFuncDict[_KT, Any]]: """A dictionary-based collection type with arbitrary keying. .. versionchanged:: 2.0 Renamed :data:`.mapped_collection` to diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 5f7ff43e4..d15c882c4 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -143,8 +143,7 @@ def _all_registries() -> Set[registry]: def _unconfigured_mappers() -> Iterator[Mapper[Any]]: for reg in _all_registries(): - for mapper in reg._mappers_to_configure(): - yield mapper + yield from reg._mappers_to_configure() _already_compiling = False @@ -905,8 +904,8 @@ class Mapper( with_polymorphic: Optional[ Tuple[ - Union[Literal["*"], Sequence[Union["Mapper[Any]", Type[Any]]]], - Optional["FromClause"], + Union[Literal["*"], Sequence[Union[Mapper[Any], Type[Any]]]], + Optional[FromClause], ] ] @@ -2518,105 +2517,85 @@ class Mapper( @HasMemoized_ro_memoized_attribute def _insert_cols_evaluating_none(self): - return dict( - ( - table, - frozenset( - col for col in columns if col.type.should_evaluate_none - ), + return { + table: frozenset( + col for col in columns if col.type.should_evaluate_none ) for table, columns in self._cols_by_table.items() - ) + } @HasMemoized.memoized_attribute def _insert_cols_as_none(self): - return dict( - ( - table, - frozenset( - col.key - for col in columns - if not col.primary_key - and not col.server_default - and not col.default - and not col.type.should_evaluate_none - ), + return { + table: frozenset( + col.key + for col in columns + if not col.primary_key + and not col.server_default + and not col.default + and not col.type.should_evaluate_none ) for table, columns in self._cols_by_table.items() - ) + } @HasMemoized.memoized_attribute def _propkey_to_col(self): - return dict( - ( - table, - dict( - (self._columntoproperty[col].key, col) for col in columns - ), - ) + return { + table: {self._columntoproperty[col].key: col for col in columns} for table, columns in self._cols_by_table.items() - ) + } @HasMemoized.memoized_attribute def _pk_keys_by_table(self): - return dict( - (table, frozenset([col.key for col in pks])) + return { + table: frozenset([col.key for col in pks]) for table, pks in self._pks_by_table.items() - ) + } @HasMemoized.memoized_attribute def _pk_attr_keys_by_table(self): - return dict( - ( - table, - frozenset([self._columntoproperty[col].key for col in pks]), - ) + return { + table: frozenset([self._columntoproperty[col].key for col in pks]) for table, pks in self._pks_by_table.items() - ) + } @HasMemoized.memoized_attribute def _server_default_cols( self, ) -> Mapping[FromClause, FrozenSet[Column[Any]]]: - return dict( - ( - table, - frozenset( - [ - col - for col in cast("Iterable[Column[Any]]", columns) - if col.server_default is not None - or ( - col.default is not None - and col.default.is_clause_element - ) - ] - ), + return { + table: frozenset( + [ + col + for col in cast("Iterable[Column[Any]]", columns) + if col.server_default is not None + or ( + col.default is not None + and col.default.is_clause_element + ) + ] ) for table, columns in self._cols_by_table.items() - ) + } @HasMemoized.memoized_attribute def _server_onupdate_default_cols( self, ) -> Mapping[FromClause, FrozenSet[Column[Any]]]: - return dict( - ( - table, - frozenset( - [ - col - for col in cast("Iterable[Column[Any]]", columns) - if col.server_onupdate is not None - or ( - col.onupdate is not None - and col.onupdate.is_clause_element - ) - ] - ), + return { + table: frozenset( + [ + col + for col in cast("Iterable[Column[Any]]", columns) + if col.server_onupdate is not None + or ( + col.onupdate is not None + and col.onupdate.is_clause_element + ) + ] ) for table, columns in self._cols_by_table.items() - ) + } @HasMemoized.memoized_attribute def _server_default_col_keys(self) -> Mapping[FromClause, FrozenSet[str]]: diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index dfb61c28a..77532f323 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -458,12 +458,12 @@ def _collect_update_commands( if bulk: # keys here are mapped attribute keys, so # look at mapper attribute keys for pk - params = dict( - (propkey_to_col[propkey].key, state_dict[propkey]) + params = { + propkey_to_col[propkey].key: state_dict[propkey] for propkey in set(propkey_to_col) .intersection(state_dict) .difference(mapper._pk_attr_keys_by_table[table]) - ) + } has_all_defaults = True else: params = {} @@ -542,12 +542,12 @@ def _collect_update_commands( if bulk: # keys here are mapped attribute keys, so # look at mapper attribute keys for pk - pk_params = dict( - (propkey_to_col[propkey]._label, state_dict.get(propkey)) + pk_params = { + propkey_to_col[propkey]._label: state_dict.get(propkey) for propkey in set(propkey_to_col).intersection( mapper._pk_attr_keys_by_table[table] ) - ) + } else: pk_params = {} for col in pks: @@ -1689,7 +1689,7 @@ def _connections_for_states(base_mapper, uowtransaction, states): def _sort_states(mapper, states): pending = set(states) - persistent = set(s for s in pending if s.key is not None) + persistent = {s for s in pending if s.key is not None} pending.difference_update(persistent) try: diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 0cbd3f713..c1da267f4 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -151,9 +151,7 @@ class ColumnProperty( doc: Optional[str] = None, _instrument: bool = True, ): - super(ColumnProperty, self).__init__( - attribute_options=attribute_options - ) + super().__init__(attribute_options=attribute_options) columns = (column,) + additional_columns self.columns = [ coercions.expect(roles.LabeledColumnExprRole, c) for c in columns @@ -211,7 +209,7 @@ class ColumnProperty( column.name = key @property - def mapper_property_to_assign(self) -> Optional["MapperProperty[_T]"]: + def mapper_property_to_assign(self) -> Optional[MapperProperty[_T]]: return self @property @@ -601,7 +599,7 @@ class MappedColumn( return self.column.name @property - def mapper_property_to_assign(self) -> Optional["MapperProperty[_T]"]: + def mapper_property_to_assign(self) -> Optional[MapperProperty[_T]]: if self.deferred: return ColumnProperty( self.column, diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 2d97754f4..0d8d21df0 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -3281,7 +3281,7 @@ class BulkUpdate(BulkUD): values: Dict[_DMLColumnArgument, Any], update_kwargs: Optional[Dict[Any, Any]], ): - super(BulkUpdate, self).__init__(query) + super().__init__(query) self.values = values self.update_kwargs = update_kwargs diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py index 986093e02..020fae600 100644 --- a/lib/sqlalchemy/orm/relationships.py +++ b/lib/sqlalchemy/orm/relationships.py @@ -2966,15 +2966,13 @@ class JoinCondition: # 2. columns that are FK but are not remote (e.g. local) # suggest manytoone. - manytoone_local = set( - [ - c - for c in self._gather_columns_with_annotation( - self.primaryjoin, "foreign" - ) - if "remote" not in c._annotations - ] - ) + manytoone_local = { + c + for c in self._gather_columns_with_annotation( + self.primaryjoin, "foreign" + ) + if "remote" not in c._annotations + } # 3. if both collections are present, remove columns that # refer to themselves. This is for the case of @@ -3204,13 +3202,11 @@ class JoinCondition: self, clause: ColumnElement[Any], *annotation: Iterable[str] ) -> Set[ColumnElement[Any]]: annotation_set = set(annotation) - return set( - [ - cast(ColumnElement[Any], col) - for col in visitors.iterate(clause, {}) - if annotation_set.issubset(col._annotations) - ] - ) + return { + cast(ColumnElement[Any], col) + for col in visitors.iterate(clause, {}) + if annotation_set.issubset(col._annotations) + } def join_targets( self, diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index b65774f0a..efa0dc680 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -157,7 +157,7 @@ class UninstrumentedColumnLoader(LoaderStrategy): __slots__ = ("columns",) def __init__(self, parent, strategy_key): - super(UninstrumentedColumnLoader, self).__init__(parent, strategy_key) + super().__init__(parent, strategy_key) self.columns = self.parent_property.columns def setup_query( @@ -197,7 +197,7 @@ class ColumnLoader(LoaderStrategy): __slots__ = "columns", "is_composite" def __init__(self, parent, strategy_key): - super(ColumnLoader, self).__init__(parent, strategy_key) + super().__init__(parent, strategy_key) self.columns = self.parent_property.columns self.is_composite = hasattr(self.parent_property, "composite_class") @@ -285,7 +285,7 @@ class ColumnLoader(LoaderStrategy): @properties.ColumnProperty.strategy_for(query_expression=True) class ExpressionColumnLoader(ColumnLoader): def __init__(self, parent, strategy_key): - super(ExpressionColumnLoader, self).__init__(parent, strategy_key) + super().__init__(parent, strategy_key) # compare to the "default" expression that is mapped in # the column. If it's sql.null, we don't need to render @@ -381,7 +381,7 @@ class DeferredColumnLoader(LoaderStrategy): __slots__ = "columns", "group", "raiseload" def __init__(self, parent, strategy_key): - super(DeferredColumnLoader, self).__init__(parent, strategy_key) + super().__init__(parent, strategy_key) if hasattr(self.parent_property, "composite_class"): raise NotImplementedError( "Deferred loading for composite " "types not implemented yet" @@ -582,7 +582,7 @@ class AbstractRelationshipLoader(LoaderStrategy): __slots__ = "mapper", "target", "uselist", "entity" def __init__(self, parent, strategy_key): - super(AbstractRelationshipLoader, self).__init__(parent, strategy_key) + super().__init__(parent, strategy_key) self.mapper = self.parent_property.mapper self.entity = self.parent_property.entity self.target = self.parent_property.target @@ -682,7 +682,7 @@ class LazyLoader( def __init__( self, parent: RelationshipProperty[Any], strategy_key: Tuple[Any, ...] ): - super(LazyLoader, self).__init__(parent, strategy_key) + super().__init__(parent, strategy_key) self._raise_always = self.strategy_opts["lazy"] == "raise" self._raise_on_sql = self.strategy_opts["lazy"] == "raise_on_sql" @@ -1431,7 +1431,7 @@ class SubqueryLoader(PostLoader): __slots__ = ("join_depth",) def __init__(self, parent, strategy_key): - super(SubqueryLoader, self).__init__(parent, strategy_key) + super().__init__(parent, strategy_key) self.join_depth = self.parent_property.join_depth def init_class_attribute(self, mapper): @@ -1560,7 +1560,7 @@ class SubqueryLoader(PostLoader): elif distinct_target_key is None: # if target_cols refer to a non-primary key or only # part of a composite primary key, set the q as distinct - for t in set(c.table for c in target_cols): + for t in {c.table for c in target_cols}: if not set(target_cols).issuperset(t.primary_key): q._distinct = True break @@ -2078,7 +2078,7 @@ class JoinedLoader(AbstractRelationshipLoader): __slots__ = "join_depth", "_aliased_class_pool" def __init__(self, parent, strategy_key): - super(JoinedLoader, self).__init__(parent, strategy_key) + super().__init__(parent, strategy_key) self.join_depth = self.parent_property.join_depth self._aliased_class_pool = [] @@ -2832,7 +2832,7 @@ class SelectInLoader(PostLoader, util.MemoizedSlots): _chunksize = 500 def __init__(self, parent, strategy_key): - super(SelectInLoader, self).__init__(parent, strategy_key) + super().__init__(parent, strategy_key) self.join_depth = self.parent_property.join_depth is_m2o = self.parent_property.direction is interfaces.MANYTOONE diff --git a/lib/sqlalchemy/orm/strategy_options.py b/lib/sqlalchemy/orm/strategy_options.py index 23b3466f5..1c48bc476 100644 --- a/lib/sqlalchemy/orm/strategy_options.py +++ b/lib/sqlalchemy/orm/strategy_options.py @@ -1872,7 +1872,7 @@ class _AttributeStrategyLoad(_LoadElement): ), ] - _of_type: Union["Mapper[Any]", "AliasedInsp[Any]", None] + _of_type: Union[Mapper[Any], AliasedInsp[Any], None] _path_with_polymorphic_path: Optional[PathRegistry] is_class_strategy = False diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py index 5e66653a3..9a8c02b6b 100644 --- a/lib/sqlalchemy/orm/unitofwork.py +++ b/lib/sqlalchemy/orm/unitofwork.py @@ -411,9 +411,9 @@ class UOWTransaction: if cycles: # if yes, break the per-mapper actions into # per-state actions - convert = dict( - (rec, set(rec.per_state_flush_actions(self))) for rec in cycles - ) + convert = { + rec: set(rec.per_state_flush_actions(self)) for rec in cycles + } # rewrite the existing dependencies to point to # the per-state actions for those per-mapper actions @@ -435,9 +435,9 @@ class UOWTransaction: for dep in convert[edge[1]]: self.dependencies.add((edge[0], dep)) - return set( - [a for a in self.postsort_actions.values() if not a.disabled] - ).difference(cycles) + return { + a for a in self.postsort_actions.values() if not a.disabled + }.difference(cycles) def execute(self) -> None: postsort_actions = self._generate_actions() @@ -478,9 +478,9 @@ class UOWTransaction: return states = set(self.states) - isdel = set( + isdel = { s for (s, (isdelete, listonly)) in self.states.items() if isdelete - ) + } other = states.difference(isdel) if isdel: self.session._remove_newly_deleted(isdel) diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 50eba5d4c..e5bdbaa4f 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -1011,7 +1011,7 @@ class AliasedInsp( our_classes = util.to_set( mp.class_ for mp in self.with_polymorphic_mappers ) - new_classes = set([mp.class_ for mp in other.with_polymorphic_mappers]) + new_classes = {mp.class_ for mp in other.with_polymorphic_mappers} if our_classes == new_classes: return other else: @@ -1278,8 +1278,7 @@ class LoaderCriteriaOption(CriteriaOption): def _all_mappers(self) -> Iterator[Mapper[Any]]: if self.entity: - for mp_ent in self.entity.mapper.self_and_descendants: - yield mp_ent + yield from self.entity.mapper.self_and_descendants else: assert self.root_entity stack = list(self.root_entity.__subclasses__()) @@ -1290,8 +1289,7 @@ class LoaderCriteriaOption(CriteriaOption): inspection.inspect(subclass, raiseerr=False), ) if ent: - for mp in ent.mapper.self_and_descendants: - yield mp + yield from ent.mapper.self_and_descendants else: stack.extend(subclass.__subclasses__()) diff --git a/lib/sqlalchemy/sql/_elements_constructors.py b/lib/sqlalchemy/sql/_elements_constructors.py index 8b8f6b010..7c5281bee 100644 --- a/lib/sqlalchemy/sql/_elements_constructors.py +++ b/lib/sqlalchemy/sql/_elements_constructors.py @@ -1127,7 +1127,7 @@ def label( name: str, element: _ColumnExpressionArgument[_T], type_: Optional[_TypeEngineArgument[_T]] = None, -) -> "Label[_T]": +) -> Label[_T]: """Return a :class:`Label` object for the given :class:`_expression.ColumnElement`. diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index 34b295113..c81891169 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -291,16 +291,14 @@ def _cloned_intersection(a, b): """ all_overlap = set(_expand_cloned(a)).intersection(_expand_cloned(b)) - return set( - elem for elem in a if all_overlap.intersection(elem._cloned_set) - ) + return {elem for elem in a if all_overlap.intersection(elem._cloned_set)} def _cloned_difference(a, b): all_overlap = set(_expand_cloned(a)).intersection(_expand_cloned(b)) - return set( + return { elem for elem in a if not all_overlap.intersection(elem._cloned_set) - ) + } class _DialectArgView(MutableMapping[str, Any]): diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py index 8074bcf8b..f48a3ccb0 100644 --- a/lib/sqlalchemy/sql/coercions.py +++ b/lib/sqlalchemy/sql/coercions.py @@ -782,7 +782,7 @@ class ExpressionElementImpl(_ColumnCoercions, RoleImpl): else: advice = None - return super(ExpressionElementImpl, self)._raise_for_expected( + return super()._raise_for_expected( element, argname=argname, resolved=resolved, advice=advice, **kw ) @@ -1096,7 +1096,7 @@ class LabeledColumnExprImpl(ExpressionElementImpl): if isinstance(resolved, roles.ExpressionElementRole): return resolved.label(None) else: - new = super(LabeledColumnExprImpl, self)._implicit_coercions( + new = super()._implicit_coercions( element, resolved, argname=argname, **kw ) if isinstance(new, roles.ExpressionElementRole): @@ -1123,7 +1123,7 @@ class ColumnsClauseImpl(_SelectIsNotFrom, _CoerceLiterals, RoleImpl): f"{', '.join(repr(e) for e in element)})?" ) - return super(ColumnsClauseImpl, self)._raise_for_expected( + return super()._raise_for_expected( element, argname=argname, resolved=resolved, advice=advice, **kw ) @@ -1370,7 +1370,7 @@ class CompoundElementImpl(_NoTextCoercion, RoleImpl): ) else: advice = None - return super(CompoundElementImpl, self)._raise_for_expected( + return super()._raise_for_expected( element, argname=argname, resolved=resolved, advice=advice, **kw ) diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 9a00afc91..17aafddad 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -115,104 +115,102 @@ if typing.TYPE_CHECKING: _FromHintsType = Dict["FromClause", str] -RESERVED_WORDS = set( - [ - "all", - "analyse", - "analyze", - "and", - "any", - "array", - "as", - "asc", - "asymmetric", - "authorization", - "between", - "binary", - "both", - "case", - "cast", - "check", - "collate", - "column", - "constraint", - "create", - "cross", - "current_date", - "current_role", - "current_time", - "current_timestamp", - "current_user", - "default", - "deferrable", - "desc", - "distinct", - "do", - "else", - "end", - "except", - "false", - "for", - "foreign", - "freeze", - "from", - "full", - "grant", - "group", - "having", - "ilike", - "in", - "initially", - "inner", - "intersect", - "into", - "is", - "isnull", - "join", - "leading", - "left", - "like", - "limit", - "localtime", - "localtimestamp", - "natural", - "new", - "not", - "notnull", - "null", - "off", - "offset", - "old", - "on", - "only", - "or", - "order", - "outer", - "overlaps", - "placing", - "primary", - "references", - "right", - "select", - "session_user", - "set", - "similar", - "some", - "symmetric", - "table", - "then", - "to", - "trailing", - "true", - "union", - "unique", - "user", - "using", - "verbose", - "when", - "where", - ] -) +RESERVED_WORDS = { + "all", + "analyse", + "analyze", + "and", + "any", + "array", + "as", + "asc", + "asymmetric", + "authorization", + "between", + "binary", + "both", + "case", + "cast", + "check", + "collate", + "column", + "constraint", + "create", + "cross", + "current_date", + "current_role", + "current_time", + "current_timestamp", + "current_user", + "default", + "deferrable", + "desc", + "distinct", + "do", + "else", + "end", + "except", + "false", + "for", + "foreign", + "freeze", + "from", + "full", + "grant", + "group", + "having", + "ilike", + "in", + "initially", + "inner", + "intersect", + "into", + "is", + "isnull", + "join", + "leading", + "left", + "like", + "limit", + "localtime", + "localtimestamp", + "natural", + "new", + "not", + "notnull", + "null", + "off", + "offset", + "old", + "on", + "only", + "or", + "order", + "outer", + "overlaps", + "placing", + "primary", + "references", + "right", + "select", + "session_user", + "set", + "similar", + "some", + "symmetric", + "table", + "then", + "to", + "trailing", + "true", + "union", + "unique", + "user", + "using", + "verbose", + "when", + "where", +} LEGAL_CHARACTERS = re.compile(r"^[A-Z0-9_$]+$", re.I) LEGAL_CHARACTERS_PLUS_SPACE = re.compile(r"^[A-Z0-9_ $]+$", re.I) @@ -505,8 +503,7 @@ class FromLinter(collections.namedtuple("FromLinter", ["froms", "edges"])): "between each element to resolve." ) froms_str = ", ".join( - '"{elem}"'.format(elem=self.froms[from_]) - for from_ in froms + f'"{self.froms[from_]}"' for from_ in froms ) message = template.format( froms=froms_str, start=self.froms[start_with] @@ -1259,11 +1256,8 @@ class SQLCompiler(Compiled): # mypy is not able to see the two value types as the above Union, # it just sees "object". don't know how to resolve - return dict( - ( - key, - value, - ) # type: ignore + return { + key: value # type: ignore for key, value in ( ( self.bind_names[bindparam], @@ -1277,7 +1271,7 @@ class SQLCompiler(Compiled): for bindparam in self.bind_names ) if value is not None - ) + } def is_subquery(self): return len(self.stack) > 1 @@ -4147,17 +4141,12 @@ class SQLCompiler(Compiled): def _setup_select_hints( self, select: Select[Any] ) -> Tuple[str, _FromHintsType]: - byfrom = dict( - [ - ( - from_, - hinttext - % {"name": from_._compiler_dispatch(self, ashint=True)}, - ) - for (from_, dialect), hinttext in select._hints.items() - if dialect in ("*", self.dialect.name) - ] - ) + byfrom = { + from_: hinttext + % {"name": from_._compiler_dispatch(self, ashint=True)} + for (from_, dialect), hinttext in select._hints.items() + if dialect in ("*", self.dialect.name) + } hint_text = self.get_select_hint_text(byfrom) return hint_text, byfrom @@ -4583,13 +4572,11 @@ class SQLCompiler(Compiled): ) def _setup_crud_hints(self, stmt, table_text): - dialect_hints = dict( - [ - (table, hint_text) - for (table, dialect), hint_text in stmt._hints.items() - if dialect in ("*", self.dialect.name) - ] - ) + dialect_hints = { + table: hint_text + for (table, dialect), hint_text in stmt._hints.items() + if dialect in ("*", self.dialect.name) + } if stmt.table in dialect_hints: table_text = self.format_from_hint_text( table_text, stmt.table, dialect_hints[stmt.table], True @@ -5318,9 +5305,7 @@ class StrSQLCompiler(SQLCompiler): if not isinstance(compiler, StrSQLCompiler): return compiler.process(element) - return super(StrSQLCompiler, self).visit_unsupported_compilation( - element, err - ) + return super().visit_unsupported_compilation(element, err) def visit_getitem_binary(self, binary, operator, **kw): return "%s[%s]" % ( @@ -6603,14 +6588,14 @@ class IdentifierPreparer: @util.memoized_property def _r_identifiers(self): - initial, final, escaped_final = [ + initial, final, escaped_final = ( re.escape(s) for s in ( self.initial_quote, self.final_quote, self._escape_identifier(self.final_quote), ) - ] + ) r = re.compile( r"(?:" r"(?:%(initial)s((?:%(escaped)s|[^%(final)s])+)%(final)s" diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index 31d127c2c..017ff7baa 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -227,15 +227,15 @@ def _get_crud_params( parameters = {} elif stmt_parameter_tuples: assert spd is not None - parameters = dict( - (_column_as_key(key), REQUIRED) + parameters = { + _column_as_key(key): REQUIRED for key in compiler.column_keys if key not in spd - ) + } else: - parameters = dict( - (_column_as_key(key), REQUIRED) for key in compiler.column_keys - ) + parameters = { + _column_as_key(key): REQUIRED for key in compiler.column_keys + } # create a list of column assignment clauses as tuples values: List[_CrudParamElement] = [] @@ -1278,10 +1278,10 @@ def _get_update_multitable_params( values, kw, ): - normalized_params = dict( - (coercions.expect(roles.DMLColumnRole, c), param) + normalized_params = { + coercions.expect(roles.DMLColumnRole, c): param for c, param in stmt_parameter_tuples - ) + } include_table = compile_state.include_table_with_column_exprs diff --git a/lib/sqlalchemy/sql/ddl.py b/lib/sqlalchemy/sql/ddl.py index fa0c25b1d..ecdc2eb63 100644 --- a/lib/sqlalchemy/sql/ddl.py +++ b/lib/sqlalchemy/sql/ddl.py @@ -176,7 +176,7 @@ class ExecutableDDLElement(roles.DDLRole, Executable, BaseDDLElement): """ _ddl_if: Optional[DDLIf] = None - target: Optional["SchemaItem"] = None + target: Optional[SchemaItem] = None def _execute_on_connection( self, connection, distilled_params, execution_options @@ -1179,12 +1179,10 @@ class SchemaDropper(InvokeDropDDLBase): def sort_tables( - tables: Iterable["Table"], - skip_fn: Optional[Callable[["ForeignKeyConstraint"], bool]] = None, - extra_dependencies: Optional[ - typing_Sequence[Tuple["Table", "Table"]] - ] = None, -) -> List["Table"]: + tables: Iterable[Table], + skip_fn: Optional[Callable[[ForeignKeyConstraint], bool]] = None, + extra_dependencies: Optional[typing_Sequence[Tuple[Table, Table]]] = None, +) -> List[Table]: """Sort a collection of :class:`_schema.Table` objects based on dependency. diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index 2d3e3598b..c279e344b 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -1179,7 +1179,7 @@ class Insert(ValuesBase): ) def __init__(self, table: _DMLTableArgument): - super(Insert, self).__init__(table) + super().__init__(table) @_generative def inline(self: SelfInsert) -> SelfInsert: @@ -1498,7 +1498,7 @@ class Update(DMLWhereBase, ValuesBase): ) def __init__(self, table: _DMLTableArgument): - super(Update, self).__init__(table) + super().__init__(table) @_generative def ordered_values( diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 044bdf585..d9a1a9358 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -3035,7 +3035,7 @@ class BooleanClauseList(ExpressionClauseList[bool]): if not self.clauses: return self else: - return super(BooleanClauseList, self).self_group(against=against) + return super().self_group(against=against) and_ = BooleanClauseList.and_ @@ -3082,7 +3082,7 @@ class Tuple(ClauseList, ColumnElement[typing_Tuple[Any, ...]]): ] self.type = sqltypes.TupleType(*[arg.type for arg in init_clauses]) - super(Tuple, self).__init__(*init_clauses) + super().__init__(*init_clauses) @property def _select_iterable(self) -> _SelectIterable: @@ -3753,8 +3753,8 @@ class BinaryExpression(OperatorExpression[_T]): if typing.TYPE_CHECKING: def __invert__( - self: "BinaryExpression[_T]", - ) -> "BinaryExpression[_T]": + self: BinaryExpression[_T], + ) -> BinaryExpression[_T]: ... @util.ro_non_memoized_property @@ -3772,7 +3772,7 @@ class BinaryExpression(OperatorExpression[_T]): modifiers=self.modifiers, ) else: - return super(BinaryExpression, self)._negate() + return super()._negate() class Slice(ColumnElement[Any]): @@ -4617,7 +4617,7 @@ class ColumnClause( if self.table is not None: return self.table.entity_namespace else: - return super(ColumnClause, self).entity_namespace + return super().entity_namespace def _clone(self, detect_subquery_cols=False, **kw): if ( @@ -4630,7 +4630,7 @@ class ColumnClause( new = table.c.corresponding_column(self) return new - return super(ColumnClause, self)._clone(**kw) + return super()._clone(**kw) @HasMemoized_ro_memoized_attribute def _from_objects(self) -> List[FromClause]: @@ -4993,7 +4993,7 @@ class AnnotatedColumnElement(Annotated): self.__dict__.pop(attr) def _with_annotations(self, values): - clone = super(AnnotatedColumnElement, self)._with_annotations(values) + clone = super()._with_annotations(values) clone.__dict__.pop("comparator", None) return clone @@ -5032,7 +5032,7 @@ class _truncated_label(quoted_name): def __new__(cls, value: str, quote: Optional[bool] = None) -> Any: quote = getattr(value, "quote", quote) # return super(_truncated_label, cls).__new__(cls, value, quote, True) - return super(_truncated_label, cls).__new__(cls, value, quote) + return super().__new__(cls, value, quote) def __reduce__(self) -> Any: return self.__class__, (str(self), self.quote) diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index fad7c28eb..5ed89bc82 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -167,9 +167,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): @property def _proxy_key(self): - return super(FunctionElement, self)._proxy_key or getattr( - self, "name", None - ) + return super()._proxy_key or getattr(self, "name", None) def _execute_on_connection( self, @@ -660,7 +658,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): ): return Grouping(self) else: - return super(FunctionElement, self).self_group(against=against) + return super().self_group(against=against) @property def entity_namespace(self): @@ -1198,7 +1196,7 @@ class ReturnTypeFromArgs(GenericFunction[_T]): ] kwargs.setdefault("type_", _type_from_args(fn_args)) kwargs["_parsed_args"] = fn_args - super(ReturnTypeFromArgs, self).__init__(*fn_args, **kwargs) + super().__init__(*fn_args, **kwargs) class coalesce(ReturnTypeFromArgs[_T]): @@ -1304,7 +1302,7 @@ class count(GenericFunction[int]): def __init__(self, expression=None, **kwargs): if expression is None: expression = literal_column("*") - super(count, self).__init__(expression, **kwargs) + super().__init__(expression, **kwargs) class current_date(AnsiFunction[datetime.date]): @@ -1411,7 +1409,7 @@ class array_agg(GenericFunction[_T]): type_from_args, dimensions=1 ) kwargs["_parsed_args"] = fn_args - super(array_agg, self).__init__(*fn_args, **kwargs) + super().__init__(*fn_args, **kwargs) class OrderedSetAgg(GenericFunction[_T]): diff --git a/lib/sqlalchemy/sql/lambdas.py b/lib/sqlalchemy/sql/lambdas.py index bbfaf47e1..26e3a21bb 100644 --- a/lib/sqlalchemy/sql/lambdas.py +++ b/lib/sqlalchemy/sql/lambdas.py @@ -439,7 +439,7 @@ class DeferredLambdaElement(LambdaElement): lambda_args: Tuple[Any, ...] = (), ): self.lambda_args = lambda_args - super(DeferredLambdaElement, self).__init__(fn, role, opts) + super().__init__(fn, role, opts) def _invoke_user_fn(self, fn, *arg): return fn(*self.lambda_args) @@ -483,7 +483,7 @@ class DeferredLambdaElement(LambdaElement): def _copy_internals( self, clone=_clone, deferred_copy_internals=None, **kw ): - super(DeferredLambdaElement, self)._copy_internals( + super()._copy_internals( clone=clone, deferred_copy_internals=deferred_copy_internals, # **kw opts=kw, diff --git a/lib/sqlalchemy/sql/operators.py b/lib/sqlalchemy/sql/operators.py index 55c275741..2d1f9caa1 100644 --- a/lib/sqlalchemy/sql/operators.py +++ b/lib/sqlalchemy/sql/operators.py @@ -66,11 +66,11 @@ class OperatorType(Protocol): def __call__( self, - left: "Operators", + left: Operators, right: Optional[Any] = None, *other: Any, **kwargs: Any, - ) -> "Operators": + ) -> Operators: ... @@ -184,7 +184,7 @@ class Operators: precedence: int = 0, is_comparison: bool = False, return_type: Optional[ - Union[Type["TypeEngine[Any]"], "TypeEngine[Any]"] + Union[Type[TypeEngine[Any]], TypeEngine[Any]] ] = None, python_impl: Optional[Callable[..., Any]] = None, ) -> Callable[[Any], Operators]: @@ -397,7 +397,7 @@ class custom_op(OperatorType, Generic[_T]): precedence: int = 0, is_comparison: bool = False, return_type: Optional[ - Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"] + Union[Type[TypeEngine[_T]], TypeEngine[_T]] ] = None, natural_self_precedent: bool = False, eager_grouping: bool = False, diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index cd10d0c4a..f76fc447c 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -920,11 +920,11 @@ class Table( :attr:`_schema.Table.indexes` """ - return set( + return { fkc.constraint for fkc in self.foreign_keys if fkc.constraint is not None - ) + } def _init_existing(self, *args: Any, **kwargs: Any) -> None: autoload_with = kwargs.pop("autoload_with", None) @@ -1895,7 +1895,7 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): # name = None is expected to be an interim state # note this use case is legacy now that ORM declarative has a # dedicated "column" construct local to the ORM - super(Column, self).__init__(name, type_) # type: ignore + super().__init__(name, type_) # type: ignore self.key = key if key is not None else name # type: ignore self.primary_key = primary_key @@ -3573,7 +3573,7 @@ class Sequence(HasSchemaAttr, IdentityOptions, DefaultGenerator): def _set_parent(self, parent: SchemaEventTarget, **kw: Any) -> None: column = parent assert isinstance(column, Column) - super(Sequence, self)._set_parent(column) + super()._set_parent(column) column._on_table_attach(self._set_table) def _copy(self) -> Sequence: @@ -3712,7 +3712,7 @@ class DefaultClause(FetchedValue): _reflected: bool = False, ) -> None: util.assert_arg_type(arg, (str, ClauseElement, TextClause), "arg") - super(DefaultClause, self).__init__(for_update) + super().__init__(for_update) self.arg = arg self.reflected = _reflected @@ -3914,9 +3914,9 @@ class ColumnCollectionMixin: # issue #3411 - don't do the per-column auto-attach if some of the # columns are specified as strings. - has_string_cols = set( + has_string_cols = { c for c in self._pending_colargs if c is not None - ).difference(col_objs) + }.difference(col_objs) if not has_string_cols: def _col_attached(column: Column[Any], table: Table) -> None: @@ -4434,7 +4434,7 @@ class ForeignKeyConstraint(ColumnCollectionConstraint): return self.elements[0].column.table def _validate_dest_table(self, table: Table) -> None: - table_keys = set([elem._table_key() for elem in self.elements]) + table_keys = {elem._table_key() for elem in self.elements} if None not in table_keys and len(table_keys) > 1: elem0, elem1 = sorted(table_keys)[0:2] raise exc.ArgumentError( @@ -4624,7 +4624,7 @@ class PrimaryKeyConstraint(ColumnCollectionConstraint): **dialect_kw: Any, ) -> None: self._implicit_generated = _implicit_generated - super(PrimaryKeyConstraint, self).__init__( + super().__init__( *columns, name=name, deferrable=deferrable, @@ -4636,7 +4636,7 @@ class PrimaryKeyConstraint(ColumnCollectionConstraint): def _set_parent(self, parent: SchemaEventTarget, **kw: Any) -> None: table = parent assert isinstance(table, Table) - super(PrimaryKeyConstraint, self)._set_parent(table) + super()._set_parent(table) if table.primary_key is not self: table.constraints.discard(table.primary_key) @@ -5219,13 +5219,9 @@ class MetaData(HasSchemaAttr): for fk in removed.foreign_keys: fk._remove_from_metadata(self) if self._schemas: - self._schemas = set( - [ - t.schema - for t in self.tables.values() - if t.schema is not None - ] - ) + self._schemas = { + t.schema for t in self.tables.values() if t.schema is not None + } def __getstate__(self) -> Dict[str, Any]: return { diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 8c64dea9d..fcffc324f 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -1301,12 +1301,12 @@ class Join(roles.DMLTableRole, FromClause): # run normal _copy_internals. the clones for # left and right will come from the clone function's # cache - super(Join, self)._copy_internals(clone=clone, **kw) + super()._copy_internals(clone=clone, **kw) self._reset_memoizations() def _refresh_for_new_column(self, column: ColumnElement[Any]) -> None: - super(Join, self)._refresh_for_new_column(column) + super()._refresh_for_new_column(column) self.left._refresh_for_new_column(column) self.right._refresh_for_new_column(column) @@ -1467,7 +1467,7 @@ class Join(roles.DMLTableRole, FromClause): # "consider_as_foreign_keys". if consider_as_foreign_keys: for const in list(constraints): - if set(f.parent for f in const.elements) != set( + if {f.parent for f in const.elements} != set( consider_as_foreign_keys ): del constraints[const] @@ -1475,7 +1475,7 @@ class Join(roles.DMLTableRole, FromClause): # if still multiple constraints, but # they all refer to the exact same end result, use it. if len(constraints) > 1: - dedupe = set(tuple(crit) for crit in constraints.values()) + dedupe = {tuple(crit) for crit in constraints.values()} if len(dedupe) == 1: key = list(constraints)[0] constraints = {key: constraints[key]} @@ -1621,7 +1621,7 @@ class AliasedReturnsRows(NoInit, NamedFromClause): self.name = name def _refresh_for_new_column(self, column: ColumnElement[Any]) -> None: - super(AliasedReturnsRows, self)._refresh_for_new_column(column) + super()._refresh_for_new_column(column) self.element._refresh_for_new_column(column) def _populate_column_collection(self): @@ -1654,7 +1654,7 @@ class AliasedReturnsRows(NoInit, NamedFromClause): ) -> None: existing_element = self.element - super(AliasedReturnsRows, self)._copy_internals(clone=clone, **kw) + super()._copy_internals(clone=clone, **kw) # the element clone is usually against a Table that returns the # same object. don't reset exported .c. collections and other @@ -1752,7 +1752,7 @@ class TableValuedAlias(LateralFromClause, Alias): table_value_type=None, joins_implicitly=False, ): - super(TableValuedAlias, self)._init(selectable, name=name) + super()._init(selectable, name=name) self.joins_implicitly = joins_implicitly self._tableval_type = ( @@ -1959,7 +1959,7 @@ class TableSample(FromClauseAlias): self.sampling = sampling self.seed = seed - super(TableSample, self)._init(selectable, name=name) + super()._init(selectable, name=name) def _get_method(self): return self.sampling @@ -2044,7 +2044,7 @@ class CTE( self._prefixes = _prefixes if _suffixes: self._suffixes = _suffixes - super(CTE, self)._init(selectable, name=name) + super()._init(selectable, name=name) def _populate_column_collection(self): if self._cte_alias is not None: @@ -2945,7 +2945,7 @@ class TableClause(roles.DMLTableRole, Immutable, NamedFromClause): return None def __init__(self, name: str, *columns: ColumnClause[Any], **kw: Any): - super(TableClause, self).__init__() + super().__init__() self.name = name self._columns = DedupeColumnCollection() self.primary_key = ColumnSet() # type: ignore @@ -3156,7 +3156,7 @@ class Values(Generative, LateralFromClause): name: Optional[str] = None, literal_binds: bool = False, ): - super(Values, self).__init__() + super().__init__() self._column_args = columns if name is None: self._unnamed = True @@ -4188,7 +4188,7 @@ class CompoundSelectState(CompileState): # TODO: this is hacky and slow hacky_subquery = self.statement.subquery() hacky_subquery.named_with_column = False - d = dict((c.key, c) for c in hacky_subquery.c) + d = {c.key: c for c in hacky_subquery.c} return d, d, d @@ -4369,7 +4369,7 @@ class CompoundSelect(HasCompileState, GenerativeSelect, ExecutableReturnsRows): ) def _refresh_for_new_column(self, column): - super(CompoundSelect, self)._refresh_for_new_column(column) + super()._refresh_for_new_column(column) for select in self.selects: select._refresh_for_new_column(column) @@ -4689,16 +4689,16 @@ class SelectState(util.MemoizedSlots, CompileState): Dict[str, ColumnElement[Any]], Dict[str, ColumnElement[Any]], ]: - with_cols: Dict[str, ColumnElement[Any]] = dict( - (c._tq_label or c.key, c) # type: ignore + with_cols: Dict[str, ColumnElement[Any]] = { + c._tq_label or c.key: c # type: ignore for c in self.statement._all_selected_columns if c._allow_label_resolve - ) - only_froms: Dict[str, ColumnElement[Any]] = dict( - (c.key, c) # type: ignore + } + only_froms: Dict[str, ColumnElement[Any]] = { + c.key: c # type: ignore for c in _select_iterables(self.froms) if c._allow_label_resolve - ) + } only_cols: Dict[str, ColumnElement[Any]] = with_cols.copy() for key, value in only_froms.items(): with_cols.setdefault(key, value) @@ -5569,7 +5569,7 @@ class Select( # 2. copy FROM collections, adding in joins that we've created. existing_from_obj = [clone(f, **kw) for f in self._from_obj] add_froms = ( - set(f for f in new_froms.values() if isinstance(f, Join)) + {f for f in new_froms.values() if isinstance(f, Join)} .difference(all_the_froms) .difference(existing_from_obj) ) @@ -5589,15 +5589,13 @@ class Select( # correlate_except, setup_joins, these clone normally. For # column-expression oriented things like raw_columns, where_criteria, # order by, we get this from the new froms. - super(Select, self)._copy_internals( - clone=clone, omit_attrs=("_from_obj",), **kw - ) + super()._copy_internals(clone=clone, omit_attrs=("_from_obj",), **kw) self._reset_memoizations() def get_children(self, **kw: Any) -> Iterable[ClauseElement]: return itertools.chain( - super(Select, self).get_children( + super().get_children( omit_attrs=("_from_obj", "_correlate", "_correlate_except"), **kw, ), diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index b98a16b6f..624b7d16e 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -134,9 +134,7 @@ class Concatenable(TypeEngineMixin): ): return operators.concat_op, self.expr.type else: - return super(Concatenable.Comparator, self)._adapt_expression( - op, other_comparator - ) + return super()._adapt_expression(op, other_comparator) comparator_factory: _ComparatorFactory[Any] = Comparator @@ -319,7 +317,7 @@ class Unicode(String): Parameters are the same as that of :class:`.String`. """ - super(Unicode, self).__init__(length=length, **kwargs) + super().__init__(length=length, **kwargs) class UnicodeText(Text): @@ -344,7 +342,7 @@ class UnicodeText(Text): Parameters are the same as that of :class:`_expression.TextClause`. """ - super(UnicodeText, self).__init__(length=length, **kwargs) + super().__init__(length=length, **kwargs) class Integer(HasExpressionLookup, TypeEngine[int]): @@ -930,7 +928,7 @@ class _Binary(TypeEngine[bytes]): if isinstance(value, str): return self else: - return super(_Binary, self).coerce_compared_value(op, value) + return super().coerce_compared_value(op, value) def get_dbapi_type(self, dbapi): return dbapi.BINARY @@ -1450,7 +1448,7 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]): self._valid_lookup[None] = self._object_lookup[None] = None - super(Enum, self).__init__(length=length) + super().__init__(length=length) if self.enum_class: kw.setdefault("name", self.enum_class.__name__.lower()) @@ -1551,9 +1549,7 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]): op: OperatorType, other_comparator: TypeEngine.Comparator[Any], ) -> Tuple[OperatorType, TypeEngine[Any]]: - op, typ = super(Enum.Comparator, self)._adapt_expression( - op, other_comparator - ) + op, typ = super()._adapt_expression(op, other_comparator) if op is operators.concat_op: typ = String(self.type.length) return op, typ @@ -1618,7 +1614,7 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]): def adapt(self, impltype, **kw): kw["_enums"] = self._enums_argument kw["_disable_warnings"] = True - return super(Enum, self).adapt(impltype, **kw) + return super().adapt(impltype, **kw) def _should_create_constraint(self, compiler, **kw): if not self._is_impl_for_variant(compiler.dialect, kw): @@ -1649,7 +1645,7 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]): assert e.table is table def literal_processor(self, dialect): - parent_processor = super(Enum, self).literal_processor(dialect) + parent_processor = super().literal_processor(dialect) def process(value): value = self._db_value_for_elem(value) @@ -1660,7 +1656,7 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]): return process def bind_processor(self, dialect): - parent_processor = super(Enum, self).bind_processor(dialect) + parent_processor = super().bind_processor(dialect) def process(value): value = self._db_value_for_elem(value) @@ -1671,7 +1667,7 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]): return process def result_processor(self, dialect, coltype): - parent_processor = super(Enum, self).result_processor(dialect, coltype) + parent_processor = super().result_processor(dialect, coltype) def process(value): if parent_processor: @@ -1690,7 +1686,7 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]): if self.enum_class: return self.enum_class else: - return super(Enum, self).python_type + return super().python_type class PickleType(TypeDecorator[object]): @@ -1739,7 +1735,7 @@ class PickleType(TypeDecorator[object]): self.protocol = protocol self.pickler = pickler or pickle self.comparator = comparator - super(PickleType, self).__init__() + super().__init__() if impl: # custom impl is not necessarily a LargeBinary subclass. @@ -2000,7 +1996,7 @@ class Interval(Emulated, _AbstractInterval, TypeDecorator[dt.timedelta]): support a "day precision" parameter, i.e. Oracle. """ - super(Interval, self).__init__() + super().__init__() self.native = native self.second_precision = second_precision self.day_precision = day_precision @@ -3005,7 +3001,7 @@ class ARRAY( def _set_parent_with_dispatch(self, parent): """Support SchemaEventTarget""" - super(ARRAY, self)._set_parent_with_dispatch(parent, outer=True) + super()._set_parent_with_dispatch(parent, outer=True) if isinstance(self.item_type, SchemaEventTarget): self.item_type._set_parent_with_dispatch(parent) @@ -3249,7 +3245,7 @@ class TIMESTAMP(DateTime): """ - super(TIMESTAMP, self).__init__(timezone=timezone) + super().__init__(timezone=timezone) def get_dbapi_type(self, dbapi): return dbapi.TIMESTAMP @@ -3464,7 +3460,7 @@ class Uuid(TypeEngine[_UUID_RETURN]): @overload def __init__( - self: "Uuid[_python_UUID]", + self: Uuid[_python_UUID], as_uuid: Literal[True] = ..., native_uuid: bool = ..., ): @@ -3472,7 +3468,7 @@ class Uuid(TypeEngine[_UUID_RETURN]): @overload def __init__( - self: "Uuid[str]", + self: Uuid[str], as_uuid: Literal[False] = ..., native_uuid: bool = ..., ): @@ -3628,11 +3624,11 @@ class UUID(Uuid[_UUID_RETURN]): __visit_name__ = "UUID" @overload - def __init__(self: "UUID[_python_UUID]", as_uuid: Literal[True] = ...): + def __init__(self: UUID[_python_UUID], as_uuid: Literal[True] = ...): ... @overload - def __init__(self: "UUID[str]", as_uuid: Literal[False] = ...): + def __init__(self: UUID[str], as_uuid: Literal[False] = ...): ... def __init__(self, as_uuid: bool = True): diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py index 135407321..866c0ccde 100644 --- a/lib/sqlalchemy/sql/traversals.py +++ b/lib/sqlalchemy/sql/traversals.py @@ -301,9 +301,7 @@ class _CopyInternalsTraversal(HasTraversalDispatch): def visit_string_clauseelement_dict( self, attrname, parent, element, clone=_clone, **kw ): - return dict( - (key, clone(value, **kw)) for key, value in element.items() - ) + return {key: clone(value, **kw) for key, value in element.items()} def visit_setup_join_tuple( self, attrname, parent, element, clone=_clone, **kw diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py index cd57ee3b6..c3768c6c6 100644 --- a/lib/sqlalchemy/sql/type_api.py +++ b/lib/sqlalchemy/sql/type_api.py @@ -1399,7 +1399,7 @@ class Emulated(TypeEngineMixin): def _is_native_for_emulated( typ: Type[Union[TypeEngine[Any], TypeEngineMixin]], -) -> TypeGuard["Type[NativeForEmulated]"]: +) -> TypeGuard[Type[NativeForEmulated]]: return hasattr(typ, "adapt_emulated_to_native") @@ -1673,9 +1673,7 @@ class TypeDecorator(SchemaEventTarget, ExternalType, TypeEngine[_T]): if TYPE_CHECKING: assert isinstance(self.expr.type, TypeDecorator) kwargs["_python_is_types"] = self.expr.type.coerce_to_is_types - return super(TypeDecorator.Comparator, self).operate( - op, *other, **kwargs - ) + return super().operate(op, *other, **kwargs) def reverse_operate( self, op: OperatorType, other: Any, **kwargs: Any @@ -1683,9 +1681,7 @@ class TypeDecorator(SchemaEventTarget, ExternalType, TypeEngine[_T]): if TYPE_CHECKING: assert isinstance(self.expr.type, TypeDecorator) kwargs["_python_is_types"] = self.expr.type.coerce_to_is_types - return super(TypeDecorator.Comparator, self).reverse_operate( - op, other, **kwargs - ) + return super().reverse_operate(op, other, **kwargs) @property def comparator_factory( # type: ignore # mypy properties bug diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index ec8ea757f..14cbe2456 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -316,8 +316,7 @@ def visit_binary_product( if isinstance(element, ColumnClause): yield element for elem in element.get_children(): - for e in visit(elem): - yield e + yield from visit(elem) list(visit(expr)) visit = None # type: ignore # remove gc cycles @@ -433,12 +432,10 @@ def expand_column_list_from_order_by(collist, order_by): in the collist. """ - cols_already_present = set( - [ - col.element if col._order_by_label_element is not None else col - for col in collist - ] - ) + cols_already_present = { + col.element if col._order_by_label_element is not None else col + for col in collist + } to_look_for = list(chain(*[unwrap_order_by(o) for o in order_by])) @@ -463,13 +460,10 @@ def clause_is_present(clause, search): def tables_from_leftmost(clause: FromClause) -> Iterator[FromClause]: if isinstance(clause, Join): - for t in tables_from_leftmost(clause.left): - yield t - for t in tables_from_leftmost(clause.right): - yield t + yield from tables_from_leftmost(clause.left) + yield from tables_from_leftmost(clause.right) elif isinstance(clause, FromGrouping): - for t in tables_from_leftmost(clause.element): - yield t + yield from tables_from_leftmost(clause.element) else: yield clause @@ -592,7 +586,7 @@ class _repr_row(_repr_base): __slots__ = ("row",) - def __init__(self, row: "Row[Any]", max_chars: int = 300): + def __init__(self, row: Row[Any], max_chars: int = 300): self.row = row self.max_chars = max_chars @@ -775,7 +769,7 @@ class _repr_params(_repr_base): ) return text - def _repr_param_tuple(self, params: "Sequence[Any]") -> str: + def _repr_param_tuple(self, params: Sequence[Any]) -> str: trunc = self.trunc ( diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py index 2fda1e9cb..d183372c3 100644 --- a/lib/sqlalchemy/testing/assertsql.py +++ b/lib/sqlalchemy/testing/assertsql.py @@ -269,9 +269,9 @@ class DialectSQL(CompiledSQL): return received_stmt == stmt def _received_statement(self, execute_observed): - received_stmt, received_params = super( - DialectSQL, self - )._received_statement(execute_observed) + received_stmt, received_params = super()._received_statement( + execute_observed + ) # TODO: why do we need this part? for real_stmt in execute_observed.statements: @@ -392,15 +392,15 @@ class EachOf(AssertRule): if self.rules and not self.rules[0].is_consumed: self.rules[0].no_more_statements() elif self.rules: - super(EachOf, self).no_more_statements() + super().no_more_statements() class Conditional(EachOf): def __init__(self, condition, rules, else_rules): if condition: - super(Conditional, self).__init__(*rules) + super().__init__(*rules) else: - super(Conditional, self).__init__(*else_rules) + super().__init__(*else_rules) class Or(AllOf): diff --git a/lib/sqlalchemy/testing/engines.py b/lib/sqlalchemy/testing/engines.py index c083f4e73..0a60a20d3 100644 --- a/lib/sqlalchemy/testing/engines.py +++ b/lib/sqlalchemy/testing/engines.py @@ -285,21 +285,21 @@ def reconnecting_engine(url=None, options=None): @typing.overload def testing_engine( - url: Optional["URL"] = None, + url: Optional[URL] = None, options: Optional[Dict[str, Any]] = None, asyncio: Literal[False] = False, transfer_staticpool: bool = False, -) -> "Engine": +) -> Engine: ... @typing.overload def testing_engine( - url: Optional["URL"] = None, + url: Optional[URL] = None, options: Optional[Dict[str, Any]] = None, asyncio: Literal[True] = True, transfer_staticpool: bool = False, -) -> "AsyncEngine": +) -> AsyncEngine: ... diff --git a/lib/sqlalchemy/testing/exclusions.py b/lib/sqlalchemy/testing/exclusions.py index 25c6a0482..3cb060d01 100644 --- a/lib/sqlalchemy/testing/exclusions.py +++ b/lib/sqlalchemy/testing/exclusions.py @@ -129,10 +129,8 @@ class compound: for fail in self.fails: if fail(config): print( - ( - "%s failed as expected (%s): %s " - % (name, fail._as_string(config), ex) - ) + "%s failed as expected (%s): %s " + % (name, fail._as_string(config), ex) ) break else: diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py index dcee3f18b..12b5acba4 100644 --- a/lib/sqlalchemy/testing/fixtures.py +++ b/lib/sqlalchemy/testing/fixtures.py @@ -814,7 +814,7 @@ class DeclarativeMappedTest(MappedTest): # sets up cls.Basic which is helpful for things like composite # classes - super(DeclarativeMappedTest, cls)._with_register_classes(fn) + super()._with_register_classes(fn) if cls._tables_metadata.tables and cls.run_create_tables: cls._tables_metadata.create_all(config.db) diff --git a/lib/sqlalchemy/testing/plugin/pytestplugin.py b/lib/sqlalchemy/testing/plugin/pytestplugin.py index 0a70f4008..d590ecbe4 100644 --- a/lib/sqlalchemy/testing/plugin/pytestplugin.py +++ b/lib/sqlalchemy/testing/plugin/pytestplugin.py @@ -49,7 +49,7 @@ def pytest_addoption(parser): required=False, help=None, # noqa ): - super(CallableAction, self).__init__( + super().__init__( option_strings=option_strings, dest=dest, nargs=0, @@ -210,7 +210,7 @@ def pytest_collection_modifyitems(session, config, items): and not item.getparent(pytest.Class).name.startswith("_") ] - test_classes = set(item.getparent(pytest.Class) for item in items) + test_classes = {item.getparent(pytest.Class) for item in items} def collect(element): for inst_or_fn in element.collect(): diff --git a/lib/sqlalchemy/testing/profiling.py b/lib/sqlalchemy/testing/profiling.py index 7672bcde5..dfc3f28f6 100644 --- a/lib/sqlalchemy/testing/profiling.py +++ b/lib/sqlalchemy/testing/profiling.py @@ -195,7 +195,7 @@ class ProfileStatsFile: def _read(self): try: profile_f = open(self.fname) - except IOError: + except OSError: return for lineno, line in enumerate(profile_f): line = line.strip() @@ -212,7 +212,7 @@ class ProfileStatsFile: profile_f.close() def _write(self): - print(("Writing profile file %s" % self.fname)) + print("Writing profile file %s" % self.fname) profile_f = open(self.fname, "w") profile_f.write(self._header()) for test_key in sorted(self.data): @@ -293,7 +293,7 @@ def count_functions(variance=0.05): else: line_no, expected_count = expected - print(("Pstats calls: %d Expected %s" % (callcount, expected_count))) + print("Pstats calls: %d Expected %s" % (callcount, expected_count)) stats.sort_stats(*re.split(r"[, ]", _profile_stats.sort)) stats.print_stats() if _profile_stats.dump: diff --git a/lib/sqlalchemy/testing/suite/test_dialect.py b/lib/sqlalchemy/testing/suite/test_dialect.py index 33e395c48..01cec1fb0 100644 --- a/lib/sqlalchemy/testing/suite/test_dialect.py +++ b/lib/sqlalchemy/testing/suite/test_dialect.py @@ -1,4 +1,3 @@ -#! coding: utf-8 # mypy: ignore-errors diff --git a/lib/sqlalchemy/testing/suite/test_reflection.py b/lib/sqlalchemy/testing/suite/test_reflection.py index 68d1c13fa..bf745095d 100644 --- a/lib/sqlalchemy/testing/suite/test_reflection.py +++ b/lib/sqlalchemy/testing/suite/test_reflection.py @@ -1871,14 +1871,12 @@ class ComponentReflectionTest(ComparesTables, OneConnectionTablesTest): # "unique constraints" are actually unique indexes (with possible # exception of a unique that is a dupe of another one in the case # of Oracle). make sure # they aren't duplicated. - idx_names = set([idx.name for idx in reflected.indexes]) - uq_names = set( - [ - uq.name - for uq in reflected.constraints - if isinstance(uq, sa.UniqueConstraint) - ] - ).difference(["unique_c_a_b"]) + idx_names = {idx.name for idx in reflected.indexes} + uq_names = { + uq.name + for uq in reflected.constraints + if isinstance(uq, sa.UniqueConstraint) + }.difference(["unique_c_a_b"]) assert not idx_names.intersection(uq_names) if names_that_duplicate_index: @@ -2519,10 +2517,10 @@ class ComponentReflectionTestExtra(ComparesIndexes, fixtures.TestBase): ) t.create(connection) eq_( - dict( - (col["name"], col["nullable"]) + { + col["name"]: col["nullable"] for col in inspect(connection).get_columns("t") - ), + }, {"a": True, "b": False}, ) @@ -2613,7 +2611,7 @@ class ComponentReflectionTestExtra(ComparesIndexes, fixtures.TestBase): # that can reflect these, since alembic looks for this opts = insp.get_foreign_keys("table")[0]["options"] - eq_(dict((k, opts[k]) for k in opts if opts[k]), {}) + eq_({k: opts[k] for k in opts if opts[k]}, {}) opts = insp.get_foreign_keys("user")[0]["options"] eq_(opts, expected) diff --git a/lib/sqlalchemy/testing/suite/test_select.py b/lib/sqlalchemy/testing/suite/test_select.py index 838b740fd..6394e4b9a 100644 --- a/lib/sqlalchemy/testing/suite/test_select.py +++ b/lib/sqlalchemy/testing/suite/test_select.py @@ -552,7 +552,7 @@ class FetchLimitOffsetTest(fixtures.TablesTest): .offset(2) ).fetchall() eq_(fa[0], (3, 3, 4)) - eq_(set(fa), set([(3, 3, 4), (4, 4, 5), (5, 4, 6)])) + eq_(set(fa), {(3, 3, 4), (4, 4, 5), (5, 4, 6)}) @testing.requires.fetch_ties @testing.requires.fetch_offset_with_options @@ -623,7 +623,7 @@ class FetchLimitOffsetTest(fixtures.TablesTest): .offset(2) ).fetchall() eq_(fa[0], (3, 3, 4)) - eq_(set(fa), set([(3, 3, 4), (4, 4, 5), (5, 4, 6)])) + eq_(set(fa), {(3, 3, 4), (4, 4, 5), (5, 4, 6)}) class SameNamedSchemaTableTest(fixtures.TablesTest): diff --git a/lib/sqlalchemy/testing/suite/test_types.py b/lib/sqlalchemy/testing/suite/test_types.py index 25ed041c2..36fd7f247 100644 --- a/lib/sqlalchemy/testing/suite/test_types.py +++ b/lib/sqlalchemy/testing/suite/test_types.py @@ -832,8 +832,8 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): result = {row[0] for row in connection.execute(t.select())} output = set(output) if filter_: - result = set(filter_(x) for x in result) - output = set(filter_(x) for x in output) + result = {filter_(x) for x in result} + output = {filter_(x) for x in output} eq_(result, output) if check_scale: eq_([str(x) for x in result], [str(x) for x in output]) @@ -969,13 +969,11 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): @testing.requires.precision_numerics_general def test_precision_decimal(self, do_numeric_test): - numbers = set( - [ - decimal.Decimal("54.234246451650"), - decimal.Decimal("0.004354"), - decimal.Decimal("900.0"), - ] - ) + numbers = { + decimal.Decimal("54.234246451650"), + decimal.Decimal("0.004354"), + decimal.Decimal("900.0"), + } do_numeric_test(Numeric(precision=18, scale=12), numbers, numbers) @@ -988,52 +986,46 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): """ - numbers = set( - [ - decimal.Decimal("1E-2"), - decimal.Decimal("1E-3"), - decimal.Decimal("1E-4"), - decimal.Decimal("1E-5"), - decimal.Decimal("1E-6"), - decimal.Decimal("1E-7"), - decimal.Decimal("1E-8"), - decimal.Decimal("0.01000005940696"), - decimal.Decimal("0.00000005940696"), - decimal.Decimal("0.00000000000696"), - decimal.Decimal("0.70000000000696"), - decimal.Decimal("696E-12"), - ] - ) + numbers = { + decimal.Decimal("1E-2"), + decimal.Decimal("1E-3"), + decimal.Decimal("1E-4"), + decimal.Decimal("1E-5"), + decimal.Decimal("1E-6"), + decimal.Decimal("1E-7"), + decimal.Decimal("1E-8"), + decimal.Decimal("0.01000005940696"), + decimal.Decimal("0.00000005940696"), + decimal.Decimal("0.00000000000696"), + decimal.Decimal("0.70000000000696"), + decimal.Decimal("696E-12"), + } do_numeric_test(Numeric(precision=18, scale=14), numbers, numbers) @testing.requires.precision_numerics_enotation_large def test_enotation_decimal_large(self, do_numeric_test): """test exceedingly large decimals.""" - numbers = set( - [ - decimal.Decimal("4E+8"), - decimal.Decimal("5748E+15"), - decimal.Decimal("1.521E+15"), - decimal.Decimal("00000000000000.1E+12"), - ] - ) + numbers = { + decimal.Decimal("4E+8"), + decimal.Decimal("5748E+15"), + decimal.Decimal("1.521E+15"), + decimal.Decimal("00000000000000.1E+12"), + } do_numeric_test(Numeric(precision=25, scale=2), numbers, numbers) @testing.requires.precision_numerics_many_significant_digits def test_many_significant_digits(self, do_numeric_test): - numbers = set( - [ - decimal.Decimal("31943874831932418390.01"), - decimal.Decimal("319438950232418390.273596"), - decimal.Decimal("87673.594069654243"), - ] - ) + numbers = { + decimal.Decimal("31943874831932418390.01"), + decimal.Decimal("319438950232418390.273596"), + decimal.Decimal("87673.594069654243"), + } do_numeric_test(Numeric(precision=38, scale=12), numbers, numbers) @testing.requires.precision_numerics_retains_significant_digits def test_numeric_no_decimal(self, do_numeric_test): - numbers = set([decimal.Decimal("1.000")]) + numbers = {decimal.Decimal("1.000")} do_numeric_test( Numeric(precision=5, scale=3), numbers, numbers, check_scale=True ) @@ -1258,7 +1250,7 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest): def default(self, o): if isinstance(o, decimal.Decimal): return str(o) - return super(DecimalEncoder, self).default(o) + return super().default(o) json_data = json.dumps(data_element, cls=DecimalEncoder) diff --git a/lib/sqlalchemy/util/_collections.py b/lib/sqlalchemy/util/_collections.py index 54be2e4e5..22df74590 100644 --- a/lib/sqlalchemy/util/_collections.py +++ b/lib/sqlalchemy/util/_collections.py @@ -185,9 +185,7 @@ class Properties(Generic[_T]): return iter(list(self._data.values())) def __dir__(self) -> List[str]: - return dir(super(Properties, self)) + [ - str(k) for k in self._data.keys() - ] + return dir(super()) + [str(k) for k in self._data.keys()] def __add__(self, other: Properties[_F]) -> List[Union[_T, _F]]: return list(self) + list(other) # type: ignore @@ -477,8 +475,7 @@ def flatten_iterator(x: Iterable[_T]) -> Iterator[_T]: elem: _T for elem in x: if not isinstance(elem, str) and hasattr(elem, "__iter__"): - for y in flatten_iterator(elem): - yield y + yield from flatten_iterator(elem) else: yield elem @@ -504,7 +501,7 @@ class LRUCache(typing.MutableMapping[_KT, _VT]): capacity: int threshold: float - size_alert: Optional[Callable[["LRUCache[_KT, _VT]"], None]] + size_alert: Optional[Callable[[LRUCache[_KT, _VT]], None]] def __init__( self, diff --git a/lib/sqlalchemy/util/_concurrency_py3k.py b/lib/sqlalchemy/util/_concurrency_py3k.py index 969e8d92e..ec9463019 100644 --- a/lib/sqlalchemy/util/_concurrency_py3k.py +++ b/lib/sqlalchemy/util/_concurrency_py3k.py @@ -32,7 +32,7 @@ if typing.TYPE_CHECKING: dead: bool gr_context: Optional[Context] - def __init__(self, fn: Callable[..., Any], driver: "greenlet"): + def __init__(self, fn: Callable[..., Any], driver: greenlet): ... def throw(self, *arg: Any) -> Any: diff --git a/lib/sqlalchemy/util/compat.py b/lib/sqlalchemy/util/compat.py index 24f9bcf10..6517e381c 100644 --- a/lib/sqlalchemy/util/compat.py +++ b/lib/sqlalchemy/util/compat.py @@ -64,11 +64,11 @@ def inspect_getfullargspec(func: Callable[..., Any]) -> FullArgSpec: if inspect.ismethod(func): func = func.__func__ if not inspect.isfunction(func): - raise TypeError("{!r} is not a Python function".format(func)) + raise TypeError(f"{func!r} is not a Python function") co = func.__code__ if not inspect.iscode(co): - raise TypeError("{!r} is not a code object".format(co)) + raise TypeError(f"{co!r} is not a code object") nargs = co.co_argcount names = co.co_varnames diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py index 051a8c89e..8df4950a3 100644 --- a/lib/sqlalchemy/util/langhelpers.py +++ b/lib/sqlalchemy/util/langhelpers.py @@ -1522,7 +1522,7 @@ class classproperty(property): fget: Callable[[Any], Any] def __init__(self, fget: Callable[[Any], Any], *arg: Any, **kw: Any): - super(classproperty, self).__init__(fget, *arg, **kw) + super().__init__(fget, *arg, **kw) self.__doc__ = fget.__doc__ def __get__(self, obj: Any, cls: Optional[type] = None) -> Any: @@ -1793,7 +1793,7 @@ class _hash_limit_string(str): interpolated = (value % args) + ( " (this warning may be suppressed after %d occurrences)" % num ) - self = super(_hash_limit_string, cls).__new__(cls, interpolated) + self = super().__new__(cls, interpolated) self._hash = hash("%s_%d" % (value, hash(interpolated) % num)) return self diff --git a/lib/sqlalchemy/util/topological.py b/lib/sqlalchemy/util/topological.py index 620e3bbb7..96aa5db2f 100644 --- a/lib/sqlalchemy/util/topological.py +++ b/lib/sqlalchemy/util/topological.py @@ -72,8 +72,7 @@ def sort( """ for set_ in sort_as_subsets(tuples, allitems): - for s in set_: - yield s + yield from set_ def find_cycles( |
