diff options
Diffstat (limited to 'lib/sqlalchemy/dialects/sqlite/base.py')
-rw-r--r-- | lib/sqlalchemy/dialects/sqlite/base.py | 908 |
1 files changed, 565 insertions, 343 deletions
diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index c487af898..cb9389af1 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -579,9 +579,20 @@ from ... import util from ...engine import default, reflection from ...sql import compiler -from ...types import (BLOB, BOOLEAN, CHAR, DECIMAL, FLOAT, - INTEGER, REAL, NUMERIC, SMALLINT, TEXT, - TIMESTAMP, VARCHAR) +from ...types import ( + BLOB, + BOOLEAN, + CHAR, + DECIMAL, + FLOAT, + INTEGER, + REAL, + NUMERIC, + SMALLINT, + TEXT, + TIMESTAMP, + VARCHAR, +) from .json import JSON, JSONIndexType, JSONPathType @@ -610,10 +621,15 @@ class _DateTimeMixin(object): """ spec = self._storage_format % { - "year": 0, "month": 0, "day": 0, "hour": 0, - "minute": 0, "second": 0, "microsecond": 0 + "year": 0, + "month": 0, + "day": 0, + "hour": 0, + "minute": 0, + "second": 0, + "microsecond": 0, } - return bool(re.search(r'[^0-9]', spec)) + return bool(re.search(r"[^0-9]", spec)) def adapt(self, cls, **kw): if issubclass(cls, _DateTimeMixin): @@ -628,6 +644,7 @@ class _DateTimeMixin(object): def process(value): return "'%s'" % bp(value) + return process @@ -671,13 +688,17 @@ class DATETIME(_DateTimeMixin, sqltypes.DateTime): ) def __init__(self, *args, **kwargs): - truncate_microseconds = kwargs.pop('truncate_microseconds', False) + truncate_microseconds = kwargs.pop("truncate_microseconds", False) super(DATETIME, self).__init__(*args, **kwargs) if truncate_microseconds: - assert 'storage_format' not in kwargs, "You can specify only "\ + assert "storage_format" not in kwargs, ( + "You can specify only " "one of truncate_microseconds or storage_format." - assert 'regexp' not in kwargs, "You can specify only one of "\ + ) + assert "regexp" not in kwargs, ( + "You can specify only one of " "truncate_microseconds or regexp." + ) self._storage_format = ( "%(year)04d-%(month)02d-%(day)02d " "%(hour)02d:%(minute)02d:%(second)02d" @@ -693,33 +714,37 @@ class DATETIME(_DateTimeMixin, sqltypes.DateTime): return None elif isinstance(value, datetime_datetime): return format % { - 'year': value.year, - 'month': value.month, - 'day': value.day, - 'hour': value.hour, - 'minute': value.minute, - 'second': value.second, - 'microsecond': value.microsecond, + "year": value.year, + "month": value.month, + "day": value.day, + "hour": value.hour, + "minute": value.minute, + "second": value.second, + "microsecond": value.microsecond, } elif isinstance(value, datetime_date): return format % { - 'year': value.year, - 'month': value.month, - 'day': value.day, - 'hour': 0, - 'minute': 0, - 'second': 0, - 'microsecond': 0, + "year": value.year, + "month": value.month, + "day": value.day, + "hour": 0, + "minute": 0, + "second": 0, + "microsecond": 0, } else: - raise TypeError("SQLite DateTime type only accepts Python " - "datetime and date objects as input.") + raise TypeError( + "SQLite DateTime type only accepts Python " + "datetime and date objects as input." + ) + return process def result_processor(self, dialect, coltype): if self._reg: return processors.str_to_datetime_processor_factory( - self._reg, datetime.datetime) + self._reg, datetime.datetime + ) else: return processors.str_to_datetime @@ -768,19 +793,23 @@ class DATE(_DateTimeMixin, sqltypes.Date): return None elif isinstance(value, datetime_date): return format % { - 'year': value.year, - 'month': value.month, - 'day': value.day, + "year": value.year, + "month": value.month, + "day": value.day, } else: - raise TypeError("SQLite Date type only accepts Python " - "date objects as input.") + raise TypeError( + "SQLite Date type only accepts Python " + "date objects as input." + ) + return process def result_processor(self, dialect, coltype): if self._reg: return processors.str_to_datetime_processor_factory( - self._reg, datetime.date) + self._reg, datetime.date + ) else: return processors.str_to_date @@ -820,13 +849,17 @@ class TIME(_DateTimeMixin, sqltypes.Time): _storage_format = "%(hour)02d:%(minute)02d:%(second)02d.%(microsecond)06d" def __init__(self, *args, **kwargs): - truncate_microseconds = kwargs.pop('truncate_microseconds', False) + truncate_microseconds = kwargs.pop("truncate_microseconds", False) super(TIME, self).__init__(*args, **kwargs) if truncate_microseconds: - assert 'storage_format' not in kwargs, "You can specify only "\ + assert "storage_format" not in kwargs, ( + "You can specify only " "one of truncate_microseconds or storage_format." - assert 'regexp' not in kwargs, "You can specify only one of "\ + ) + assert "regexp" not in kwargs, ( + "You can specify only one of " "truncate_microseconds or regexp." + ) self._storage_format = "%(hour)02d:%(minute)02d:%(second)02d" def bind_processor(self, dialect): @@ -838,23 +871,28 @@ class TIME(_DateTimeMixin, sqltypes.Time): return None elif isinstance(value, datetime_time): return format % { - 'hour': value.hour, - 'minute': value.minute, - 'second': value.second, - 'microsecond': value.microsecond, + "hour": value.hour, + "minute": value.minute, + "second": value.second, + "microsecond": value.microsecond, } else: - raise TypeError("SQLite Time type only accepts Python " - "time objects as input.") + raise TypeError( + "SQLite Time type only accepts Python " + "time objects as input." + ) + return process def result_processor(self, dialect, coltype): if self._reg: return processors.str_to_datetime_processor_factory( - self._reg, datetime.time) + self._reg, datetime.time + ) else: return processors.str_to_time + colspecs = { sqltypes.Date: DATE, sqltypes.DateTime: DATETIME, @@ -865,31 +903,31 @@ colspecs = { } ischema_names = { - 'BIGINT': sqltypes.BIGINT, - 'BLOB': sqltypes.BLOB, - 'BOOL': sqltypes.BOOLEAN, - 'BOOLEAN': sqltypes.BOOLEAN, - 'CHAR': sqltypes.CHAR, - 'DATE': sqltypes.DATE, - 'DATE_CHAR': sqltypes.DATE, - 'DATETIME': sqltypes.DATETIME, - 'DATETIME_CHAR': sqltypes.DATETIME, - 'DOUBLE': sqltypes.FLOAT, - 'DECIMAL': sqltypes.DECIMAL, - 'FLOAT': sqltypes.FLOAT, - 'INT': sqltypes.INTEGER, - 'INTEGER': sqltypes.INTEGER, - 'JSON': JSON, - 'NUMERIC': sqltypes.NUMERIC, - 'REAL': sqltypes.REAL, - 'SMALLINT': sqltypes.SMALLINT, - 'TEXT': sqltypes.TEXT, - 'TIME': sqltypes.TIME, - 'TIME_CHAR': sqltypes.TIME, - 'TIMESTAMP': sqltypes.TIMESTAMP, - 'VARCHAR': sqltypes.VARCHAR, - 'NVARCHAR': sqltypes.NVARCHAR, - 'NCHAR': sqltypes.NCHAR, + "BIGINT": sqltypes.BIGINT, + "BLOB": sqltypes.BLOB, + "BOOL": sqltypes.BOOLEAN, + "BOOLEAN": sqltypes.BOOLEAN, + "CHAR": sqltypes.CHAR, + "DATE": sqltypes.DATE, + "DATE_CHAR": sqltypes.DATE, + "DATETIME": sqltypes.DATETIME, + "DATETIME_CHAR": sqltypes.DATETIME, + "DOUBLE": sqltypes.FLOAT, + "DECIMAL": sqltypes.DECIMAL, + "FLOAT": sqltypes.FLOAT, + "INT": sqltypes.INTEGER, + "INTEGER": sqltypes.INTEGER, + "JSON": JSON, + "NUMERIC": sqltypes.NUMERIC, + "REAL": sqltypes.REAL, + "SMALLINT": sqltypes.SMALLINT, + "TEXT": sqltypes.TEXT, + "TIME": sqltypes.TIME, + "TIME_CHAR": sqltypes.TIME, + "TIMESTAMP": sqltypes.TIMESTAMP, + "VARCHAR": sqltypes.VARCHAR, + "NVARCHAR": sqltypes.NVARCHAR, + "NCHAR": sqltypes.NCHAR, } @@ -897,17 +935,18 @@ class SQLiteCompiler(compiler.SQLCompiler): extract_map = util.update_copy( compiler.SQLCompiler.extract_map, { - 'month': '%m', - 'day': '%d', - 'year': '%Y', - 'second': '%S', - 'hour': '%H', - 'doy': '%j', - 'minute': '%M', - 'epoch': '%s', - 'dow': '%w', - 'week': '%W', - }) + "month": "%m", + "day": "%d", + "year": "%Y", + "second": "%S", + "hour": "%H", + "doy": "%j", + "minute": "%M", + "epoch": "%s", + "dow": "%w", + "week": "%W", + }, + ) def visit_now_func(self, fn, **kw): return "CURRENT_TIMESTAMP" @@ -916,10 +955,10 @@ class SQLiteCompiler(compiler.SQLCompiler): return 'DATETIME(CURRENT_TIMESTAMP, "localtime")' def visit_true(self, expr, **kw): - return '1' + return "1" def visit_false(self, expr, **kw): - return '0' + return "0" def visit_char_length_func(self, fn, **kw): return "length%s" % self.function_argspec(fn) @@ -934,11 +973,12 @@ class SQLiteCompiler(compiler.SQLCompiler): try: return "CAST(STRFTIME('%s', %s) AS INTEGER)" % ( self.extract_map[extract.field], - self.process(extract.expr, **kw) + self.process(extract.expr, **kw), ) except KeyError: raise exc.CompileError( - "%s is not a valid extract argument." % extract.field) + "%s is not a valid extract argument." % extract.field + ) def limit_clause(self, select, **kw): text = "" @@ -954,35 +994,41 @@ class SQLiteCompiler(compiler.SQLCompiler): def for_update_clause(self, select, **kw): # sqlite has no "FOR UPDATE" AFAICT - return '' + return "" def visit_is_distinct_from_binary(self, binary, operator, **kw): - return "%s IS NOT %s" % (self.process(binary.left), - self.process(binary.right)) + return "%s IS NOT %s" % ( + self.process(binary.left), + self.process(binary.right), + ) def visit_isnot_distinct_from_binary(self, binary, operator, **kw): - return "%s IS %s" % (self.process(binary.left), - self.process(binary.right)) + return "%s IS %s" % ( + self.process(binary.left), + self.process(binary.right), + ) def visit_json_getitem_op_binary(self, binary, operator, **kw): return "JSON_QUOTE(JSON_EXTRACT(%s, %s))" % ( self.process(binary.left, **kw), - self.process(binary.right, **kw)) + self.process(binary.right, **kw), + ) def visit_json_path_getitem_op_binary(self, binary, operator, **kw): return "JSON_QUOTE(JSON_EXTRACT(%s, %s))" % ( self.process(binary.left, **kw), - self.process(binary.right, **kw)) + self.process(binary.right, **kw), + ) def visit_empty_set_expr(self, type_): - return 'SELECT 1 FROM (SELECT 1) WHERE 1!=1' + return "SELECT 1 FROM (SELECT 1) WHERE 1!=1" class SQLiteDDLCompiler(compiler.DDLCompiler): - def get_column_specification(self, column, **kwargs): coltype = self.dialect.type_compiler.process( - column.type, type_expression=column) + column.type, type_expression=column + ) colspec = self.preparer.format_column(column) + " " + coltype default = self.get_column_default_string(column) if default is not None: @@ -991,29 +1037,33 @@ class SQLiteDDLCompiler(compiler.DDLCompiler): if not column.nullable: colspec += " NOT NULL" - on_conflict_clause = column.dialect_options['sqlite'][ - 'on_conflict_not_null'] + on_conflict_clause = column.dialect_options["sqlite"][ + "on_conflict_not_null" + ] if on_conflict_clause is not None: colspec += " ON CONFLICT " + on_conflict_clause if column.primary_key: if ( - column.autoincrement is True and - len(column.table.primary_key.columns) != 1 + column.autoincrement is True + and len(column.table.primary_key.columns) != 1 ): raise exc.CompileError( "SQLite does not support autoincrement for " - "composite primary keys") + "composite primary keys" + ) - if (column.table.dialect_options['sqlite']['autoincrement'] and - len(column.table.primary_key.columns) == 1 and - issubclass( - column.type._type_affinity, sqltypes.Integer) and - not column.foreign_keys): + if ( + column.table.dialect_options["sqlite"]["autoincrement"] + and len(column.table.primary_key.columns) == 1 + and issubclass(column.type._type_affinity, sqltypes.Integer) + and not column.foreign_keys + ): colspec += " PRIMARY KEY" - on_conflict_clause = column.dialect_options['sqlite'][ - 'on_conflict_primary_key'] + on_conflict_clause = column.dialect_options["sqlite"][ + "on_conflict_primary_key" + ] if on_conflict_clause is not None: colspec += " ON CONFLICT " + on_conflict_clause @@ -1027,21 +1077,25 @@ class SQLiteDDLCompiler(compiler.DDLCompiler): # with the column itself. if len(constraint.columns) == 1: c = list(constraint)[0] - if (c.primary_key and - c.table.dialect_options['sqlite']['autoincrement'] and - issubclass(c.type._type_affinity, sqltypes.Integer) and - not c.foreign_keys): + if ( + c.primary_key + and c.table.dialect_options["sqlite"]["autoincrement"] + and issubclass(c.type._type_affinity, sqltypes.Integer) + and not c.foreign_keys + ): return None - text = super( - SQLiteDDLCompiler, - self).visit_primary_key_constraint(constraint) + text = super(SQLiteDDLCompiler, self).visit_primary_key_constraint( + constraint + ) - on_conflict_clause = constraint.dialect_options['sqlite'][ - 'on_conflict'] + on_conflict_clause = constraint.dialect_options["sqlite"][ + "on_conflict" + ] if on_conflict_clause is None and len(constraint.columns) == 1: - on_conflict_clause = list(constraint)[0].\ - dialect_options['sqlite']['on_conflict_primary_key'] + on_conflict_clause = list(constraint)[0].dialect_options["sqlite"][ + "on_conflict_primary_key" + ] if on_conflict_clause is not None: text += " ON CONFLICT " + on_conflict_clause @@ -1049,15 +1103,17 @@ class SQLiteDDLCompiler(compiler.DDLCompiler): return text def visit_unique_constraint(self, constraint): - text = super( - SQLiteDDLCompiler, - self).visit_unique_constraint(constraint) + text = super(SQLiteDDLCompiler, self).visit_unique_constraint( + constraint + ) - on_conflict_clause = constraint.dialect_options['sqlite'][ - 'on_conflict'] + on_conflict_clause = constraint.dialect_options["sqlite"][ + "on_conflict" + ] if on_conflict_clause is None and len(constraint.columns) == 1: - on_conflict_clause = list(constraint)[0].\ - dialect_options['sqlite']['on_conflict_unique'] + on_conflict_clause = list(constraint)[0].dialect_options["sqlite"][ + "on_conflict_unique" + ] if on_conflict_clause is not None: text += " ON CONFLICT " + on_conflict_clause @@ -1065,12 +1121,13 @@ class SQLiteDDLCompiler(compiler.DDLCompiler): return text def visit_check_constraint(self, constraint): - text = super( - SQLiteDDLCompiler, - self).visit_check_constraint(constraint) + text = super(SQLiteDDLCompiler, self).visit_check_constraint( + constraint + ) - on_conflict_clause = constraint.dialect_options['sqlite'][ - 'on_conflict'] + on_conflict_clause = constraint.dialect_options["sqlite"][ + "on_conflict" + ] if on_conflict_clause is not None: text += " ON CONFLICT " + on_conflict_clause @@ -1078,14 +1135,15 @@ class SQLiteDDLCompiler(compiler.DDLCompiler): return text def visit_column_check_constraint(self, constraint): - text = super( - SQLiteDDLCompiler, - self).visit_column_check_constraint(constraint) + text = super(SQLiteDDLCompiler, self).visit_column_check_constraint( + constraint + ) - if constraint.dialect_options['sqlite']['on_conflict'] is not None: + if constraint.dialect_options["sqlite"]["on_conflict"] is not None: raise exc.CompileError( "SQLite does not support on conflict clause for " - "column check constraint") + "column check constraint" + ) return text @@ -1097,40 +1155,40 @@ class SQLiteDDLCompiler(compiler.DDLCompiler): if local_table.schema != remote_table.schema: return None else: - return super( - SQLiteDDLCompiler, - self).visit_foreign_key_constraint(constraint) + return super(SQLiteDDLCompiler, self).visit_foreign_key_constraint( + constraint + ) def define_constraint_remote_table(self, constraint, table, preparer): """Format the remote table clause of a CREATE CONSTRAINT clause.""" return preparer.format_table(table, use_schema=False) - def visit_create_index(self, create, include_schema=False, - include_table_schema=True): + def visit_create_index( + self, create, include_schema=False, include_table_schema=True + ): index = create.element self._verify_index_table(index) preparer = self.preparer text = "CREATE " if index.unique: text += "UNIQUE " - text += "INDEX %s ON %s (%s)" \ - % ( - self._prepared_index_name(index, - include_schema=True), - preparer.format_table(index.table, - use_schema=False), - ', '.join( - self.sql_compiler.process( - expr, include_table=False, literal_binds=True) for - expr in index.expressions) - ) + text += "INDEX %s ON %s (%s)" % ( + self._prepared_index_name(index, include_schema=True), + preparer.format_table(index.table, use_schema=False), + ", ".join( + self.sql_compiler.process( + expr, include_table=False, literal_binds=True + ) + for expr in index.expressions + ), + ) whereclause = index.dialect_options["sqlite"]["where"] if whereclause is not None: where_compiled = self.sql_compiler.process( - whereclause, include_table=False, - literal_binds=True) + whereclause, include_table=False, literal_binds=True + ) text += " WHERE " + where_compiled return text @@ -1141,22 +1199,28 @@ class SQLiteTypeCompiler(compiler.GenericTypeCompiler): return self.visit_BLOB(type_) def visit_DATETIME(self, type_, **kw): - if not isinstance(type_, _DateTimeMixin) or \ - type_.format_is_text_affinity: + if ( + not isinstance(type_, _DateTimeMixin) + or type_.format_is_text_affinity + ): return super(SQLiteTypeCompiler, self).visit_DATETIME(type_) else: return "DATETIME_CHAR" def visit_DATE(self, type_, **kw): - if not isinstance(type_, _DateTimeMixin) or \ - type_.format_is_text_affinity: + if ( + not isinstance(type_, _DateTimeMixin) + or type_.format_is_text_affinity + ): return super(SQLiteTypeCompiler, self).visit_DATE(type_) else: return "DATE_CHAR" def visit_TIME(self, type_, **kw): - if not isinstance(type_, _DateTimeMixin) or \ - type_.format_is_text_affinity: + if ( + not isinstance(type_, _DateTimeMixin) + or type_.format_is_text_affinity + ): return super(SQLiteTypeCompiler, self).visit_TIME(type_) else: return "TIME_CHAR" @@ -1169,33 +1233,135 @@ class SQLiteTypeCompiler(compiler.GenericTypeCompiler): class SQLiteIdentifierPreparer(compiler.IdentifierPreparer): - reserved_words = set([ - 'add', 'after', 'all', 'alter', 'analyze', 'and', 'as', 'asc', - 'attach', 'autoincrement', 'before', 'begin', 'between', 'by', - 'cascade', 'case', 'cast', 'check', 'collate', 'column', 'commit', - 'conflict', 'constraint', 'create', 'cross', 'current_date', - 'current_time', 'current_timestamp', 'database', 'default', - 'deferrable', 'deferred', 'delete', 'desc', 'detach', 'distinct', - 'drop', 'each', 'else', 'end', 'escape', 'except', 'exclusive', - 'explain', 'false', 'fail', 'for', 'foreign', 'from', 'full', 'glob', - 'group', 'having', 'if', 'ignore', 'immediate', 'in', 'index', - 'indexed', 'initially', 'inner', 'insert', 'instead', 'intersect', - 'into', 'is', 'isnull', 'join', 'key', 'left', 'like', 'limit', - 'match', 'natural', 'not', 'notnull', 'null', 'of', 'offset', 'on', - 'or', 'order', 'outer', 'plan', 'pragma', 'primary', 'query', - 'raise', 'references', 'reindex', 'rename', 'replace', 'restrict', - 'right', 'rollback', 'row', 'select', 'set', 'table', 'temp', - 'temporary', 'then', 'to', 'transaction', 'trigger', 'true', 'union', - 'unique', 'update', 'using', 'vacuum', 'values', 'view', 'virtual', - 'when', 'where', - ]) + reserved_words = set( + [ + "add", + "after", + "all", + "alter", + "analyze", + "and", + "as", + "asc", + "attach", + "autoincrement", + "before", + "begin", + "between", + "by", + "cascade", + "case", + "cast", + "check", + "collate", + "column", + "commit", + "conflict", + "constraint", + "create", + "cross", + "current_date", + "current_time", + "current_timestamp", + "database", + "default", + "deferrable", + "deferred", + "delete", + "desc", + "detach", + "distinct", + "drop", + "each", + "else", + "end", + "escape", + "except", + "exclusive", + "explain", + "false", + "fail", + "for", + "foreign", + "from", + "full", + "glob", + "group", + "having", + "if", + "ignore", + "immediate", + "in", + "index", + "indexed", + "initially", + "inner", + "insert", + "instead", + "intersect", + "into", + "is", + "isnull", + "join", + "key", + "left", + "like", + "limit", + "match", + "natural", + "not", + "notnull", + "null", + "of", + "offset", + "on", + "or", + "order", + "outer", + "plan", + "pragma", + "primary", + "query", + "raise", + "references", + "reindex", + "rename", + "replace", + "restrict", + "right", + "rollback", + "row", + "select", + "set", + "table", + "temp", + "temporary", + "then", + "to", + "transaction", + "trigger", + "true", + "union", + "unique", + "update", + "using", + "vacuum", + "values", + "view", + "virtual", + "when", + "where", + ] + ) class SQLiteExecutionContext(default.DefaultExecutionContext): @util.memoized_property def _preserve_raw_colnames(self): - return not self.dialect._broken_dotted_colnames or \ - self.execution_options.get("sqlite_raw_colnames", False) + return ( + not self.dialect._broken_dotted_colnames + or self.execution_options.get("sqlite_raw_colnames", False) + ) def _translate_colname(self, colname): # TODO: detect SQLite version 3.10.0 or greater; @@ -1212,7 +1378,7 @@ class SQLiteExecutionContext(default.DefaultExecutionContext): class SQLiteDialect(default.DefaultDialect): - name = 'sqlite' + name = "sqlite" supports_alter = False supports_unicode_statements = True supports_unicode_binds = True @@ -1221,7 +1387,7 @@ class SQLiteDialect(default.DefaultDialect): supports_cast = True supports_multivalues_insert = True - default_paramstyle = 'qmark' + default_paramstyle = "qmark" execution_ctx_cls = SQLiteExecutionContext statement_compiler = SQLiteCompiler ddl_compiler = SQLiteDDLCompiler @@ -1235,27 +1401,30 @@ class SQLiteDialect(default.DefaultDialect): supports_default_values = True construct_arguments = [ - (sa_schema.Table, { - "autoincrement": False - }), - (sa_schema.Index, { - "where": None, - }), - (sa_schema.Column, { - "on_conflict_primary_key": None, - "on_conflict_not_null": None, - "on_conflict_unique": None, - }), - (sa_schema.Constraint, { - "on_conflict": None, - }), + (sa_schema.Table, {"autoincrement": False}), + (sa_schema.Index, {"where": None}), + ( + sa_schema.Column, + { + "on_conflict_primary_key": None, + "on_conflict_not_null": None, + "on_conflict_unique": None, + }, + ), + (sa_schema.Constraint, {"on_conflict": None}), ] _broken_fk_pragma_quotes = False _broken_dotted_colnames = False - def __init__(self, isolation_level=None, native_datetime=False, - _json_serializer=None, _json_deserializer=None, **kwargs): + def __init__( + self, + isolation_level=None, + native_datetime=False, + _json_serializer=None, + _json_deserializer=None, + **kwargs + ): default.DefaultDialect.__init__(self, **kwargs) self.isolation_level = isolation_level self._json_serializer = _json_serializer @@ -1269,35 +1438,42 @@ class SQLiteDialect(default.DefaultDialect): if self.dbapi is not None: self.supports_right_nested_joins = ( - self.dbapi.sqlite_version_info >= (3, 7, 16)) - self._broken_dotted_colnames = ( - self.dbapi.sqlite_version_info < (3, 10, 0) + self.dbapi.sqlite_version_info >= (3, 7, 16) + ) + self._broken_dotted_colnames = self.dbapi.sqlite_version_info < ( + 3, + 10, + 0, + ) + self.supports_default_values = self.dbapi.sqlite_version_info >= ( + 3, + 3, + 8, ) - self.supports_default_values = ( - self.dbapi.sqlite_version_info >= (3, 3, 8)) - self.supports_cast = ( - self.dbapi.sqlite_version_info >= (3, 2, 3)) + self.supports_cast = self.dbapi.sqlite_version_info >= (3, 2, 3) self.supports_multivalues_insert = ( # http://www.sqlite.org/releaselog/3_7_11.html - self.dbapi.sqlite_version_info >= (3, 7, 11)) + self.dbapi.sqlite_version_info + >= (3, 7, 11) + ) # see http://www.sqlalchemy.org/trac/ticket/2568 # as well as http://www.sqlite.org/src/info/600482d161 - self._broken_fk_pragma_quotes = ( - self.dbapi.sqlite_version_info < (3, 6, 14)) + self._broken_fk_pragma_quotes = self.dbapi.sqlite_version_info < ( + 3, + 6, + 14, + ) - _isolation_lookup = { - 'READ UNCOMMITTED': 1, - 'SERIALIZABLE': 0, - } + _isolation_lookup = {"READ UNCOMMITTED": 1, "SERIALIZABLE": 0} def set_isolation_level(self, connection, level): try: - isolation_level = self._isolation_lookup[level.replace('_', ' ')] + isolation_level = self._isolation_lookup[level.replace("_", " ")] except KeyError: raise exc.ArgumentError( "Invalid value '%s' for isolation_level. " - "Valid isolation levels for %s are %s" % - (level, self.name, ", ".join(self._isolation_lookup)) + "Valid isolation levels for %s are %s" + % (level, self.name, ", ".join(self._isolation_lookup)) ) cursor = connection.cursor() cursor.execute("PRAGMA read_uncommitted = %d" % isolation_level) @@ -1305,7 +1481,7 @@ class SQLiteDialect(default.DefaultDialect): def get_isolation_level(self, connection): cursor = connection.cursor() - cursor.execute('PRAGMA read_uncommitted') + cursor.execute("PRAGMA read_uncommitted") res = cursor.fetchone() if res: value = res[0] @@ -1327,8 +1503,10 @@ class SQLiteDialect(default.DefaultDialect): def on_connect(self): if self.isolation_level is not None: + def connect(conn): self.set_isolation_level(conn, self.isolation_level) + return connect else: return None @@ -1344,44 +1522,51 @@ class SQLiteDialect(default.DefaultDialect): def get_table_names(self, connection, schema=None, **kw): if schema is not None: qschema = self.identifier_preparer.quote_identifier(schema) - master = '%s.sqlite_master' % qschema + master = "%s.sqlite_master" % qschema else: master = "sqlite_master" - s = ("SELECT name FROM %s " - "WHERE type='table' ORDER BY name") % (master,) + s = ("SELECT name FROM %s " "WHERE type='table' ORDER BY name") % ( + master, + ) rs = connection.execute(s) return [row[0] for row in rs] @reflection.cache def get_temp_table_names(self, connection, **kw): - s = "SELECT name FROM sqlite_temp_master "\ + s = ( + "SELECT name FROM sqlite_temp_master " "WHERE type='table' ORDER BY name " + ) rs = connection.execute(s) return [row[0] for row in rs] @reflection.cache def get_temp_view_names(self, connection, **kw): - s = "SELECT name FROM sqlite_temp_master "\ + s = ( + "SELECT name FROM sqlite_temp_master " "WHERE type='view' ORDER BY name " + ) rs = connection.execute(s) return [row[0] for row in rs] def has_table(self, connection, table_name, schema=None): info = self._get_table_pragma( - connection, "table_info", table_name, schema=schema) + connection, "table_info", table_name, schema=schema + ) return bool(info) @reflection.cache def get_view_names(self, connection, schema=None, **kw): if schema is not None: qschema = self.identifier_preparer.quote_identifier(schema) - master = '%s.sqlite_master' % qschema + master = "%s.sqlite_master" % qschema else: master = "sqlite_master" - s = ("SELECT name FROM %s " - "WHERE type='view' ORDER BY name") % (master,) + s = ("SELECT name FROM %s " "WHERE type='view' ORDER BY name") % ( + master, + ) rs = connection.execute(s) return [row[0] for row in rs] @@ -1390,21 +1575,27 @@ class SQLiteDialect(default.DefaultDialect): def get_view_definition(self, connection, view_name, schema=None, **kw): if schema is not None: qschema = self.identifier_preparer.quote_identifier(schema) - master = '%s.sqlite_master' % qschema - s = ("SELECT sql FROM %s WHERE name = '%s'" - "AND type='view'") % (master, view_name) + master = "%s.sqlite_master" % qschema + s = ("SELECT sql FROM %s WHERE name = '%s'" "AND type='view'") % ( + master, + view_name, + ) rs = connection.execute(s) else: try: - s = ("SELECT sql FROM " - " (SELECT * FROM sqlite_master UNION ALL " - " SELECT * FROM sqlite_temp_master) " - "WHERE name = '%s' " - "AND type='view'") % view_name + s = ( + "SELECT sql FROM " + " (SELECT * FROM sqlite_master UNION ALL " + " SELECT * FROM sqlite_temp_master) " + "WHERE name = '%s' " + "AND type='view'" + ) % view_name rs = connection.execute(s) except exc.DBAPIError: - s = ("SELECT sql FROM sqlite_master WHERE name = '%s' " - "AND type='view'") % view_name + s = ( + "SELECT sql FROM sqlite_master WHERE name = '%s' " + "AND type='view'" + ) % view_name rs = connection.execute(s) result = rs.fetchall() @@ -1414,15 +1605,24 @@ class SQLiteDialect(default.DefaultDialect): @reflection.cache def get_columns(self, connection, table_name, schema=None, **kw): info = self._get_table_pragma( - connection, "table_info", table_name, schema=schema) + connection, "table_info", table_name, schema=schema + ) columns = [] for row in info: (name, type_, nullable, default, primary_key) = ( - row[1], row[2].upper(), not row[3], row[4], row[5]) + row[1], + row[2].upper(), + not row[3], + row[4], + row[5], + ) - columns.append(self._get_column_info(name, type_, nullable, - default, primary_key)) + columns.append( + self._get_column_info( + name, type_, nullable, default, primary_key + ) + ) return columns def _get_column_info(self, name, type_, nullable, default, primary_key): @@ -1432,12 +1632,12 @@ class SQLiteDialect(default.DefaultDialect): default = util.text_type(default) return { - 'name': name, - 'type': coltype, - 'nullable': nullable, - 'default': default, - 'autoincrement': 'auto', - 'primary_key': primary_key, + "name": name, + "type": coltype, + "nullable": nullable, + "default": default, + "autoincrement": "auto", + "primary_key": primary_key, } def _resolve_type_affinity(self, type_): @@ -1457,36 +1657,37 @@ class SQLiteDialect(default.DefaultDialect): DATE and DOUBLE). """ - match = re.match(r'([\w ]+)(\(.*?\))?', type_) + match = re.match(r"([\w ]+)(\(.*?\))?", type_) if match: coltype = match.group(1) args = match.group(2) else: - coltype = '' - args = '' + coltype = "" + args = "" if coltype in self.ischema_names: coltype = self.ischema_names[coltype] - elif 'INT' in coltype: + elif "INT" in coltype: coltype = sqltypes.INTEGER - elif 'CHAR' in coltype or 'CLOB' in coltype or 'TEXT' in coltype: + elif "CHAR" in coltype or "CLOB" in coltype or "TEXT" in coltype: coltype = sqltypes.TEXT - elif 'BLOB' in coltype or not coltype: + elif "BLOB" in coltype or not coltype: coltype = sqltypes.NullType - elif 'REAL' in coltype or 'FLOA' in coltype or 'DOUB' in coltype: + elif "REAL" in coltype or "FLOA" in coltype or "DOUB" in coltype: coltype = sqltypes.REAL else: coltype = sqltypes.NUMERIC if args is not None: - args = re.findall(r'(\d+)', args) + args = re.findall(r"(\d+)", args) try: coltype = coltype(*[int(a) for a in args]) except TypeError: util.warn( "Could not instantiate type %s with " - "reflected arguments %s; using no arguments." % - (coltype, args)) + "reflected arguments %s; using no arguments." + % (coltype, args) + ) coltype = coltype() else: coltype = coltype() @@ -1498,58 +1699,59 @@ class SQLiteDialect(default.DefaultDialect): constraint_name = None table_data = self._get_table_sql(connection, table_name, schema=schema) if table_data: - PK_PATTERN = r'CONSTRAINT (\w+) PRIMARY KEY' + PK_PATTERN = r"CONSTRAINT (\w+) PRIMARY KEY" result = re.search(PK_PATTERN, table_data, re.I) constraint_name = result.group(1) if result else None cols = self.get_columns(connection, table_name, schema, **kw) pkeys = [] for col in cols: - if col['primary_key']: - pkeys.append(col['name']) + if col["primary_key"]: + pkeys.append(col["name"]) - return {'constrained_columns': pkeys, 'name': constraint_name} + return {"constrained_columns": pkeys, "name": constraint_name} @reflection.cache def get_foreign_keys(self, connection, table_name, schema=None, **kw): # sqlite makes this *extremely difficult*. # First, use the pragma to get the actual FKs. pragma_fks = self._get_table_pragma( - connection, "foreign_key_list", - table_name, schema=schema + connection, "foreign_key_list", table_name, schema=schema ) fks = {} for row in pragma_fks: - (numerical_id, rtbl, lcol, rcol) = ( - row[0], row[2], row[3], row[4]) + (numerical_id, rtbl, lcol, rcol) = (row[0], row[2], row[3], row[4]) if rcol is None: rcol = lcol if self._broken_fk_pragma_quotes: - rtbl = re.sub(r'^[\"\[`\']|[\"\]`\']$', '', rtbl) + rtbl = re.sub(r"^[\"\[`\']|[\"\]`\']$", "", rtbl) if numerical_id in fks: fk = fks[numerical_id] else: fk = fks[numerical_id] = { - 'name': None, - 'constrained_columns': [], - 'referred_schema': schema, - 'referred_table': rtbl, - 'referred_columns': [], - 'options': {} + "name": None, + "constrained_columns": [], + "referred_schema": schema, + "referred_table": rtbl, + "referred_columns": [], + "options": {}, } fks[numerical_id] = fk - fk['constrained_columns'].append(lcol) - fk['referred_columns'].append(rcol) + fk["constrained_columns"].append(lcol) + fk["referred_columns"].append(rcol) def fk_sig(constrained_columns, referred_table, referred_columns): - return tuple(constrained_columns) + (referred_table,) + \ - tuple(referred_columns) + return ( + tuple(constrained_columns) + + (referred_table,) + + tuple(referred_columns) + ) # then, parse the actual SQL and attempt to find DDL that matches # the names as well. SQLite saves the DDL in whatever format @@ -1558,10 +1760,13 @@ class SQLiteDialect(default.DefaultDialect): keys_by_signature = dict( ( fk_sig( - fk['constrained_columns'], - fk['referred_table'], fk['referred_columns']), - fk - ) for fk in fks.values() + fk["constrained_columns"], + fk["referred_table"], + fk["referred_columns"], + ), + fk, + ) + for fk in fks.values() ) table_data = self._get_table_sql(connection, table_name, schema=schema) @@ -1571,55 +1776,66 @@ class SQLiteDialect(default.DefaultDialect): def parse_fks(): FK_PATTERN = ( - r'(?:CONSTRAINT (\w+) +)?' - r'FOREIGN KEY *\( *(.+?) *\) +' + r"(?:CONSTRAINT (\w+) +)?" + r"FOREIGN KEY *\( *(.+?) *\) +" r'REFERENCES +(?:(?:"(.+?)")|([a-z0-9_]+)) *\((.+?)\) *' - r'((?:ON (?:DELETE|UPDATE) ' - r'(?:SET NULL|SET DEFAULT|CASCADE|RESTRICT|NO ACTION) *)*)' + r"((?:ON (?:DELETE|UPDATE) " + r"(?:SET NULL|SET DEFAULT|CASCADE|RESTRICT|NO ACTION) *)*)" ) for match in re.finditer(FK_PATTERN, table_data, re.I): ( - constraint_name, constrained_columns, - referred_quoted_name, referred_name, - referred_columns, onupdatedelete) = \ - match.group(1, 2, 3, 4, 5, 6) + constraint_name, + constrained_columns, + referred_quoted_name, + referred_name, + referred_columns, + onupdatedelete, + ) = match.group(1, 2, 3, 4, 5, 6) constrained_columns = list( - self._find_cols_in_sig(constrained_columns)) + self._find_cols_in_sig(constrained_columns) + ) if not referred_columns: referred_columns = constrained_columns else: referred_columns = list( - self._find_cols_in_sig(referred_columns)) + self._find_cols_in_sig(referred_columns) + ) referred_name = referred_quoted_name or referred_name options = {} for token in re.split(r" *\bON\b *", onupdatedelete.upper()): if token.startswith("DELETE"): - options['ondelete'] = token[6:].strip() + options["ondelete"] = token[6:].strip() elif token.startswith("UPDATE"): options["onupdate"] = token[6:].strip() yield ( - constraint_name, constrained_columns, - referred_name, referred_columns, options) + constraint_name, + constrained_columns, + referred_name, + referred_columns, + options, + ) + fkeys = [] for ( - constraint_name, constrained_columns, - referred_name, referred_columns, options) in parse_fks(): - sig = fk_sig( - constrained_columns, referred_name, referred_columns) + constraint_name, + constrained_columns, + referred_name, + referred_columns, + options, + ) in parse_fks(): + sig = fk_sig(constrained_columns, referred_name, referred_columns) if sig not in keys_by_signature: util.warn( "WARNING: SQL-parsed foreign key constraint " "'%s' could not be located in PRAGMA " - "foreign_keys for table %s" % ( - sig, - table_name - )) + "foreign_keys for table %s" % (sig, table_name) + ) continue key = keys_by_signature.pop(sig) - key['name'] = constraint_name - key['options'] = options + key["name"] = constraint_name + key["options"] = options fkeys.append(key) # assume the remainders are the unnamed, inline constraints, just # use them as is as it's extremely difficult to parse inline @@ -1632,20 +1848,26 @@ class SQLiteDialect(default.DefaultDialect): yield match.group(1) or match.group(2) @reflection.cache - def get_unique_constraints(self, connection, table_name, - schema=None, **kw): + def get_unique_constraints( + self, connection, table_name, schema=None, **kw + ): auto_index_by_sig = {} for idx in self.get_indexes( - connection, table_name, schema=schema, - include_auto_indexes=True, **kw): - if not idx['name'].startswith("sqlite_autoindex"): + connection, + table_name, + schema=schema, + include_auto_indexes=True, + **kw + ): + if not idx["name"].startswith("sqlite_autoindex"): continue - sig = tuple(idx['column_names']) + sig = tuple(idx["column_names"]) auto_index_by_sig[sig] = idx table_data = self._get_table_sql( - connection, table_name, schema=schema, **kw) + connection, table_name, schema=schema, **kw + ) if not table_data: return [] @@ -1654,8 +1876,8 @@ class SQLiteDialect(default.DefaultDialect): def parse_uqs(): UNIQUE_PATTERN = r'(?:CONSTRAINT "?(.+?)"? +)?UNIQUE *\((.+?)\)' INLINE_UNIQUE_PATTERN = ( - r'(?:(".+?")|([a-z0-9]+)) ' - r'+[a-z0-9_ ]+? +UNIQUE') + r'(?:(".+?")|([a-z0-9]+)) ' r"+[a-z0-9_ ]+? +UNIQUE" + ) for match in re.finditer(UNIQUE_PATTERN, table_data, re.I): name, cols = match.group(1, 2) @@ -1666,34 +1888,29 @@ class SQLiteDialect(default.DefaultDialect): # are kind of the same thing :) for match in re.finditer(INLINE_UNIQUE_PATTERN, table_data, re.I): cols = list( - self._find_cols_in_sig(match.group(1) or match.group(2))) + self._find_cols_in_sig(match.group(1) or match.group(2)) + ) yield None, cols for name, cols in parse_uqs(): sig = tuple(cols) if sig in auto_index_by_sig: auto_index_by_sig.pop(sig) - parsed_constraint = { - 'name': name, - 'column_names': cols - } + parsed_constraint = {"name": name, "column_names": cols} unique_constraints.append(parsed_constraint) # NOTE: auto_index_by_sig might not be empty here, # the PRIMARY KEY may have an entry. return unique_constraints @reflection.cache - def get_check_constraints(self, connection, table_name, - schema=None, **kw): + def get_check_constraints(self, connection, table_name, schema=None, **kw): table_data = self._get_table_sql( - connection, table_name, schema=schema, **kw) + connection, table_name, schema=schema, **kw + ) if not table_data: return [] - CHECK_PATTERN = ( - r'(?:CONSTRAINT (\w+) +)?' - r'CHECK *\( *(.+) *\),? *' - ) + CHECK_PATTERN = r"(?:CONSTRAINT (\w+) +)?" r"CHECK *\( *(.+) *\),? *" check_constraints = [] # NOTE: we aren't using re.S here because we actually are # taking advantage of each CHECK constraint being all on one @@ -1701,25 +1918,26 @@ class SQLiteDialect(default.DefaultDialect): # necessarily makes assumptions as to how the CREATE TABLE # was emitted. for match in re.finditer(CHECK_PATTERN, table_data, re.I): - check_constraints.append({ - 'sqltext': match.group(2), - 'name': match.group(1) - }) + check_constraints.append( + {"sqltext": match.group(2), "name": match.group(1)} + ) return check_constraints @reflection.cache def get_indexes(self, connection, table_name, schema=None, **kw): pragma_indexes = self._get_table_pragma( - connection, "index_list", table_name, schema=schema) + connection, "index_list", table_name, schema=schema + ) indexes = [] - include_auto_indexes = kw.pop('include_auto_indexes', False) + include_auto_indexes = kw.pop("include_auto_indexes", False) for row in pragma_indexes: # ignore implicit primary key index. # http://www.mail-archive.com/sqlite-users@sqlite.org/msg30517.html - if (not include_auto_indexes and - row[1].startswith('sqlite_autoindex')): + if not include_auto_indexes and row[1].startswith( + "sqlite_autoindex" + ): continue indexes.append(dict(name=row[1], column_names=[], unique=row[2])) @@ -1727,34 +1945,38 @@ class SQLiteDialect(default.DefaultDialect): # loop thru unique indexes to get the column names. for idx in indexes: pragma_index = self._get_table_pragma( - connection, "index_info", idx['name']) + connection, "index_info", idx["name"] + ) for row in pragma_index: - idx['column_names'].append(row[2]) + idx["column_names"].append(row[2]) return indexes @reflection.cache def _get_table_sql(self, connection, table_name, schema=None, **kw): if schema: schema_expr = "%s." % ( - self.identifier_preparer.quote_identifier(schema)) + self.identifier_preparer.quote_identifier(schema) + ) else: schema_expr = "" try: - s = ("SELECT sql FROM " - " (SELECT * FROM %(schema)ssqlite_master UNION ALL " - " SELECT * FROM %(schema)ssqlite_temp_master) " - "WHERE name = '%(table)s' " - "AND type = 'table'" % { - "schema": schema_expr, - "table": table_name}) + s = ( + "SELECT sql FROM " + " (SELECT * FROM %(schema)ssqlite_master UNION ALL " + " SELECT * FROM %(schema)ssqlite_temp_master) " + "WHERE name = '%(table)s' " + "AND type = 'table'" + % {"schema": schema_expr, "table": table_name} + ) rs = connection.execute(s) except exc.DBAPIError: - s = ("SELECT sql FROM %(schema)ssqlite_master " - "WHERE name = '%(table)s' " - "AND type = 'table'" % { - "schema": schema_expr, - "table": table_name}) + s = ( + "SELECT sql FROM %(schema)ssqlite_master " + "WHERE name = '%(table)s' " + "AND type = 'table'" + % {"schema": schema_expr, "table": table_name} + ) rs = connection.execute(s) return rs.scalar() |