diff options
Diffstat (limited to 'lib/sqlalchemy/dialects/firebird/base.py')
| -rw-r--r-- | lib/sqlalchemy/dialects/firebird/base.py | 569 |
1 files changed, 399 insertions, 170 deletions
diff --git a/lib/sqlalchemy/dialects/firebird/base.py b/lib/sqlalchemy/dialects/firebird/base.py index 7b470c189..1e9c778f3 100644 --- a/lib/sqlalchemy/dialects/firebird/base.py +++ b/lib/sqlalchemy/dialects/firebird/base.py @@ -79,48 +79,254 @@ from sqlalchemy.engine import base, default, reflection from sqlalchemy.sql import compiler from sqlalchemy.sql.elements import quoted_name -from sqlalchemy.types import (BIGINT, BLOB, DATE, FLOAT, INTEGER, NUMERIC, - SMALLINT, TEXT, TIME, TIMESTAMP, Integer) - - -RESERVED_WORDS = set([ - "active", "add", "admin", "after", "all", "alter", "and", "any", "as", - "asc", "ascending", "at", "auto", "avg", "before", "begin", "between", - "bigint", "bit_length", "blob", "both", "by", "case", "cast", "char", - "character", "character_length", "char_length", "check", "close", - "collate", "column", "commit", "committed", "computed", "conditional", - "connect", "constraint", "containing", "count", "create", "cross", - "cstring", "current", "current_connection", "current_date", - "current_role", "current_time", "current_timestamp", - "current_transaction", "current_user", "cursor", "database", "date", - "day", "dec", "decimal", "declare", "default", "delete", "desc", - "descending", "disconnect", "distinct", "do", "domain", "double", - "drop", "else", "end", "entry_point", "escape", "exception", - "execute", "exists", "exit", "external", "extract", "fetch", "file", - "filter", "float", "for", "foreign", "from", "full", "function", - "gdscode", "generator", "gen_id", "global", "grant", "group", - "having", "hour", "if", "in", "inactive", "index", "inner", - "input_type", "insensitive", "insert", "int", "integer", "into", "is", - "isolation", "join", "key", "leading", "left", "length", "level", - "like", "long", "lower", "manual", "max", "maximum_segment", "merge", - "min", "minute", "module_name", "month", "names", "national", - "natural", "nchar", "no", "not", "null", "numeric", "octet_length", - "of", "on", "only", "open", "option", "or", "order", "outer", - "output_type", "overflow", "page", "pages", "page_size", "parameter", - "password", "plan", "position", "post_event", "precision", "primary", - "privileges", "procedure", "protected", "rdb$db_key", "read", "real", - "record_version", "recreate", "recursive", "references", "release", - "reserv", "reserving", "retain", "returning_values", "returns", - "revoke", "right", "rollback", "rows", "row_count", "savepoint", - "schema", "second", "segment", "select", "sensitive", "set", "shadow", - "shared", "singular", "size", "smallint", "snapshot", "some", "sort", - "sqlcode", "stability", "start", "starting", "starts", "statistics", - "sub_type", "sum", "suspend", "table", "then", "time", "timestamp", - "to", "trailing", "transaction", "trigger", "trim", "uncommitted", - "union", "unique", "update", "upper", "user", "using", "value", - "values", "varchar", "variable", "varying", "view", "wait", "when", - "where", "while", "with", "work", "write", "year", -]) +from sqlalchemy.types import ( + BIGINT, + BLOB, + DATE, + FLOAT, + INTEGER, + NUMERIC, + SMALLINT, + TEXT, + TIME, + TIMESTAMP, + Integer, +) + + +RESERVED_WORDS = set( + [ + "active", + "add", + "admin", + "after", + "all", + "alter", + "and", + "any", + "as", + "asc", + "ascending", + "at", + "auto", + "avg", + "before", + "begin", + "between", + "bigint", + "bit_length", + "blob", + "both", + "by", + "case", + "cast", + "char", + "character", + "character_length", + "char_length", + "check", + "close", + "collate", + "column", + "commit", + "committed", + "computed", + "conditional", + "connect", + "constraint", + "containing", + "count", + "create", + "cross", + "cstring", + "current", + "current_connection", + "current_date", + "current_role", + "current_time", + "current_timestamp", + "current_transaction", + "current_user", + "cursor", + "database", + "date", + "day", + "dec", + "decimal", + "declare", + "default", + "delete", + "desc", + "descending", + "disconnect", + "distinct", + "do", + "domain", + "double", + "drop", + "else", + "end", + "entry_point", + "escape", + "exception", + "execute", + "exists", + "exit", + "external", + "extract", + "fetch", + "file", + "filter", + "float", + "for", + "foreign", + "from", + "full", + "function", + "gdscode", + "generator", + "gen_id", + "global", + "grant", + "group", + "having", + "hour", + "if", + "in", + "inactive", + "index", + "inner", + "input_type", + "insensitive", + "insert", + "int", + "integer", + "into", + "is", + "isolation", + "join", + "key", + "leading", + "left", + "length", + "level", + "like", + "long", + "lower", + "manual", + "max", + "maximum_segment", + "merge", + "min", + "minute", + "module_name", + "month", + "names", + "national", + "natural", + "nchar", + "no", + "not", + "null", + "numeric", + "octet_length", + "of", + "on", + "only", + "open", + "option", + "or", + "order", + "outer", + "output_type", + "overflow", + "page", + "pages", + "page_size", + "parameter", + "password", + "plan", + "position", + "post_event", + "precision", + "primary", + "privileges", + "procedure", + "protected", + "rdb$db_key", + "read", + "real", + "record_version", + "recreate", + "recursive", + "references", + "release", + "reserv", + "reserving", + "retain", + "returning_values", + "returns", + "revoke", + "right", + "rollback", + "rows", + "row_count", + "savepoint", + "schema", + "second", + "segment", + "select", + "sensitive", + "set", + "shadow", + "shared", + "singular", + "size", + "smallint", + "snapshot", + "some", + "sort", + "sqlcode", + "stability", + "start", + "starting", + "starts", + "statistics", + "sub_type", + "sum", + "suspend", + "table", + "then", + "time", + "timestamp", + "to", + "trailing", + "transaction", + "trigger", + "trim", + "uncommitted", + "union", + "unique", + "update", + "upper", + "user", + "using", + "value", + "values", + "varchar", + "variable", + "varying", + "view", + "wait", + "when", + "where", + "while", + "with", + "work", + "write", + "year", + ] +) class _StringType(sqltypes.String): @@ -133,7 +339,8 @@ class _StringType(sqltypes.String): class VARCHAR(_StringType, sqltypes.VARCHAR): """Firebird VARCHAR type""" - __visit_name__ = 'VARCHAR' + + __visit_name__ = "VARCHAR" def __init__(self, length=None, **kwargs): super(VARCHAR, self).__init__(length=length, **kwargs) @@ -141,7 +348,8 @@ class VARCHAR(_StringType, sqltypes.VARCHAR): class CHAR(_StringType, sqltypes.CHAR): """Firebird CHAR type""" - __visit_name__ = 'CHAR' + + __visit_name__ = "CHAR" def __init__(self, length=None, **kwargs): super(CHAR, self).__init__(length=length, **kwargs) @@ -154,32 +362,33 @@ class _FBDateTime(sqltypes.DateTime): return datetime.datetime(value.year, value.month, value.day) else: return value + return process -colspecs = { - sqltypes.DateTime: _FBDateTime -} + +colspecs = {sqltypes.DateTime: _FBDateTime} ischema_names = { - 'SHORT': SMALLINT, - 'LONG': INTEGER, - 'QUAD': FLOAT, - 'FLOAT': FLOAT, - 'DATE': DATE, - 'TIME': TIME, - 'TEXT': TEXT, - 'INT64': BIGINT, - 'DOUBLE': FLOAT, - 'TIMESTAMP': TIMESTAMP, - 'VARYING': VARCHAR, - 'CSTRING': CHAR, - 'BLOB': BLOB, + "SHORT": SMALLINT, + "LONG": INTEGER, + "QUAD": FLOAT, + "FLOAT": FLOAT, + "DATE": DATE, + "TIME": TIME, + "TEXT": TEXT, + "INT64": BIGINT, + "DOUBLE": FLOAT, + "TIMESTAMP": TIMESTAMP, + "VARYING": VARCHAR, + "CSTRING": CHAR, + "BLOB": BLOB, } # TODO: date conversion types (should be implemented as _FBDateTime, # _FBDate, etc. as bind/result functionality is required) + class FBTypeCompiler(compiler.GenericTypeCompiler): def visit_boolean(self, type_, **kw): return self.visit_SMALLINT(type_, **kw) @@ -194,11 +403,11 @@ class FBTypeCompiler(compiler.GenericTypeCompiler): return "BLOB SUB_TYPE 0" def _extend_string(self, type_, basic): - charset = getattr(type_, 'charset', None) + charset = getattr(type_, "charset", None) if charset is None: return basic else: - return '%s CHARACTER SET %s' % (basic, charset) + return "%s CHARACTER SET %s" % (basic, charset) def visit_CHAR(self, type_, **kw): basic = super(FBTypeCompiler, self).visit_CHAR(type_, **kw) @@ -207,8 +416,8 @@ class FBTypeCompiler(compiler.GenericTypeCompiler): def visit_VARCHAR(self, type_, **kw): if not type_.length: raise exc.CompileError( - "VARCHAR requires a length on dialect %s" % - self.dialect.name) + "VARCHAR requires a length on dialect %s" % self.dialect.name + ) basic = super(FBTypeCompiler, self).visit_VARCHAR(type_, **kw) return self._extend_string(type_, basic) @@ -228,36 +437,42 @@ class FBCompiler(sql.compiler.SQLCompiler): return "CURRENT_TIMESTAMP" def visit_startswith_op_binary(self, binary, operator, **kw): - return '%s STARTING WITH %s' % ( + return "%s STARTING WITH %s" % ( binary.left._compiler_dispatch(self, **kw), - binary.right._compiler_dispatch(self, **kw)) + binary.right._compiler_dispatch(self, **kw), + ) def visit_notstartswith_op_binary(self, binary, operator, **kw): - return '%s NOT STARTING WITH %s' % ( + return "%s NOT STARTING WITH %s" % ( binary.left._compiler_dispatch(self, **kw), - binary.right._compiler_dispatch(self, **kw)) + binary.right._compiler_dispatch(self, **kw), + ) def visit_mod_binary(self, binary, operator, **kw): return "mod(%s, %s)" % ( self.process(binary.left, **kw), - self.process(binary.right, **kw)) + self.process(binary.right, **kw), + ) def visit_alias(self, alias, asfrom=False, **kwargs): if self.dialect._version_two: - return super(FBCompiler, self).\ - visit_alias(alias, asfrom=asfrom, **kwargs) + return super(FBCompiler, self).visit_alias( + alias, asfrom=asfrom, **kwargs + ) else: # Override to not use the AS keyword which FB 1.5 does not like if asfrom: - alias_name = isinstance(alias.name, - expression._truncated_label) and \ - self._truncated_identifier("alias", - alias.name) or alias.name - - return self.process( - alias.original, asfrom=asfrom, **kwargs) + \ - " " + \ - self.preparer.format_alias(alias, alias_name) + alias_name = ( + isinstance(alias.name, expression._truncated_label) + and self._truncated_identifier("alias", alias.name) + or alias.name + ) + + return ( + self.process(alias.original, asfrom=asfrom, **kwargs) + + " " + + self.preparer.format_alias(alias, alias_name) + ) else: return self.process(alias.original, **kwargs) @@ -320,7 +535,7 @@ class FBCompiler(sql.compiler.SQLCompiler): for c in expression._select_iterables(returning_cols) ] - return 'RETURNING ' + ', '.join(columns) + return "RETURNING " + ", ".join(columns) class FBDDLCompiler(sql.compiler.DDLCompiler): @@ -333,27 +548,33 @@ class FBDDLCompiler(sql.compiler.DDLCompiler): # http://www.firebirdsql.org/manual/generatorguide-sqlsyntax.html if create.element.start is not None: raise NotImplemented( - "Firebird SEQUENCE doesn't support START WITH") + "Firebird SEQUENCE doesn't support START WITH" + ) if create.element.increment is not None: raise NotImplemented( - "Firebird SEQUENCE doesn't support INCREMENT BY") + "Firebird SEQUENCE doesn't support INCREMENT BY" + ) if self.dialect._version_two: - return "CREATE SEQUENCE %s" % \ - self.preparer.format_sequence(create.element) + return "CREATE SEQUENCE %s" % self.preparer.format_sequence( + create.element + ) else: - return "CREATE GENERATOR %s" % \ - self.preparer.format_sequence(create.element) + return "CREATE GENERATOR %s" % self.preparer.format_sequence( + create.element + ) def visit_drop_sequence(self, drop): """Generate a ``DROP GENERATOR`` statement for the sequence.""" if self.dialect._version_two: - return "DROP SEQUENCE %s" % \ - self.preparer.format_sequence(drop.element) + return "DROP SEQUENCE %s" % self.preparer.format_sequence( + drop.element + ) else: - return "DROP GENERATOR %s" % \ - self.preparer.format_sequence(drop.element) + return "DROP GENERATOR %s" % self.preparer.format_sequence( + drop.element + ) class FBIdentifierPreparer(sql.compiler.IdentifierPreparer): @@ -361,7 +582,8 @@ class FBIdentifierPreparer(sql.compiler.IdentifierPreparer): reserved_words = RESERVED_WORDS illegal_initial_characters = compiler.ILLEGAL_INITIAL_CHARACTERS.union( - ['_']) + ["_"] + ) def __init__(self, dialect): super(FBIdentifierPreparer, self).__init__(dialect, omit_schema=True) @@ -372,16 +594,16 @@ class FBExecutionContext(default.DefaultExecutionContext): """Get the next value from the sequence using ``gen_id()``.""" return self._execute_scalar( - "SELECT gen_id(%s, 1) FROM rdb$database" % - self.dialect.identifier_preparer.format_sequence(seq), - type_ + "SELECT gen_id(%s, 1) FROM rdb$database" + % self.dialect.identifier_preparer.format_sequence(seq), + type_, ) class FBDialect(default.DefaultDialect): """Firebird dialect""" - name = 'firebird' + name = "firebird" max_identifier_length = 31 @@ -413,23 +635,23 @@ class FBDialect(default.DefaultDialect): def initialize(self, connection): super(FBDialect, self).initialize(connection) - self._version_two = ('firebird' in self.server_version_info and - self.server_version_info >= (2, ) - ) or \ - ('interbase' in self.server_version_info and - self.server_version_info >= (6, ) - ) + self._version_two = ( + "firebird" in self.server_version_info + and self.server_version_info >= (2,) + ) or ( + "interbase" in self.server_version_info + and self.server_version_info >= (6,) + ) if not self._version_two: # TODO: whatever other pre < 2.0 stuff goes here self.ischema_names = ischema_names.copy() - self.ischema_names['TIMESTAMP'] = sqltypes.DATE - self.colspecs = { - sqltypes.DateTime: sqltypes.DATE - } + self.ischema_names["TIMESTAMP"] = sqltypes.DATE + self.colspecs = {sqltypes.DateTime: sqltypes.DATE} - self.implicit_returning = self._version_two and \ - self.__dict__.get('implicit_returning', True) + self.implicit_returning = self._version_two and self.__dict__.get( + "implicit_returning", True + ) def normalize_name(self, name): # Remove trailing spaces: FB uses a CHAR() type, @@ -437,8 +659,9 @@ class FBDialect(default.DefaultDialect): name = name and name.rstrip() if name is None: return None - elif name.upper() == name and \ - not self.identifier_preparer._requires_quotes(name.lower()): + elif name.upper() == name and not self.identifier_preparer._requires_quotes( + name.lower() + ): return name.lower() elif name.lower() == name: return quoted_name(name, quote=True) @@ -448,8 +671,9 @@ class FBDialect(default.DefaultDialect): def denormalize_name(self, name): if name is None: return None - elif name.lower() == name and \ - not self.identifier_preparer._requires_quotes(name.lower()): + elif name.lower() == name and not self.identifier_preparer._requires_quotes( + name.lower() + ): return name.upper() else: return name @@ -522,7 +746,7 @@ class FBDialect(default.DefaultDialect): rp = connection.execute(qry, [self.denormalize_name(view_name)]) row = rp.first() if row: - return row['view_source'] + return row["view_source"] else: return None @@ -538,13 +762,13 @@ class FBDialect(default.DefaultDialect): tablename = self.denormalize_name(table_name) # get primary key fields c = connection.execute(keyqry, ["PRIMARY KEY", tablename]) - pkfields = [self.normalize_name(r['fname']) for r in c.fetchall()] - return {'constrained_columns': pkfields, 'name': None} + pkfields = [self.normalize_name(r["fname"]) for r in c.fetchall()] + return {"constrained_columns": pkfields, "name": None} @reflection.cache - def get_column_sequence(self, connection, - table_name, column_name, - schema=None, **kw): + def get_column_sequence( + self, connection, table_name, column_name, schema=None, **kw + ): tablename = self.denormalize_name(table_name) colname = self.denormalize_name(column_name) # Heuristic-query to determine the generator associated to a PK field @@ -567,7 +791,7 @@ class FBDialect(default.DefaultDialect): """ genr = connection.execute(genqry, [tablename, colname]).first() if genr is not None: - return dict(name=self.normalize_name(genr['fgenerator'])) + return dict(name=self.normalize_name(genr["fgenerator"])) @reflection.cache def get_columns(self, connection, table_name, schema=None, **kw): @@ -595,7 +819,7 @@ class FBDialect(default.DefaultDialect): """ # get the PK, used to determine the eventual associated sequence pk_constraint = self.get_pk_constraint(connection, table_name) - pkey_cols = pk_constraint['constrained_columns'] + pkey_cols = pk_constraint["constrained_columns"] tablename = self.denormalize_name(table_name) # get all of the fields for this table @@ -605,26 +829,28 @@ class FBDialect(default.DefaultDialect): row = c.fetchone() if row is None: break - name = self.normalize_name(row['fname']) - orig_colname = row['fname'] + name = self.normalize_name(row["fname"]) + orig_colname = row["fname"] # get the data type - colspec = row['ftype'].rstrip() + colspec = row["ftype"].rstrip() coltype = self.ischema_names.get(colspec) if coltype is None: - util.warn("Did not recognize type '%s' of column '%s'" % - (colspec, name)) + util.warn( + "Did not recognize type '%s' of column '%s'" + % (colspec, name) + ) coltype = sqltypes.NULLTYPE - elif issubclass(coltype, Integer) and row['fprec'] != 0: + elif issubclass(coltype, Integer) and row["fprec"] != 0: coltype = NUMERIC( - precision=row['fprec'], - scale=row['fscale'] * -1) - elif colspec in ('VARYING', 'CSTRING'): - coltype = coltype(row['flen']) - elif colspec == 'TEXT': - coltype = TEXT(row['flen']) - elif colspec == 'BLOB': - if row['stype'] == 1: + precision=row["fprec"], scale=row["fscale"] * -1 + ) + elif colspec in ("VARYING", "CSTRING"): + coltype = coltype(row["flen"]) + elif colspec == "TEXT": + coltype = TEXT(row["flen"]) + elif colspec == "BLOB": + if row["stype"] == 1: coltype = TEXT() else: coltype = BLOB() @@ -633,36 +859,36 @@ class FBDialect(default.DefaultDialect): # does it have a default value? defvalue = None - if row['fdefault'] is not None: + if row["fdefault"] is not None: # the value comes down as "DEFAULT 'value'": there may be # more than one whitespace around the "DEFAULT" keyword # and it may also be lower case # (see also http://tracker.firebirdsql.org/browse/CORE-356) - defexpr = row['fdefault'].lstrip() - assert defexpr[:8].rstrip().upper() == \ - 'DEFAULT', "Unrecognized default value: %s" % \ - defexpr + defexpr = row["fdefault"].lstrip() + assert defexpr[:8].rstrip().upper() == "DEFAULT", ( + "Unrecognized default value: %s" % defexpr + ) defvalue = defexpr[8:].strip() - if defvalue == 'NULL': + if defvalue == "NULL": # Redundant defvalue = None col_d = { - 'name': name, - 'type': coltype, - 'nullable': not bool(row['null_flag']), - 'default': defvalue, - 'autoincrement': 'auto', + "name": name, + "type": coltype, + "nullable": not bool(row["null_flag"]), + "default": defvalue, + "autoincrement": "auto", } if orig_colname.lower() == orig_colname: - col_d['quote'] = True + col_d["quote"] = True # if the PK is a single field, try to see if its linked to # a sequence thru a trigger if len(pkey_cols) == 1 and name == pkey_cols[0]: seq_d = self.get_column_sequence(connection, tablename, name) if seq_d is not None: - col_d['sequence'] = seq_d + col_d["sequence"] = seq_d cols.append(col_d) return cols @@ -689,24 +915,26 @@ class FBDialect(default.DefaultDialect): tablename = self.denormalize_name(table_name) c = connection.execute(fkqry, ["FOREIGN KEY", tablename]) - fks = util.defaultdict(lambda: { - 'name': None, - 'constrained_columns': [], - 'referred_schema': None, - 'referred_table': None, - 'referred_columns': [] - }) + fks = util.defaultdict( + lambda: { + "name": None, + "constrained_columns": [], + "referred_schema": None, + "referred_table": None, + "referred_columns": [], + } + ) for row in c: - cname = self.normalize_name(row['cname']) + cname = self.normalize_name(row["cname"]) fk = fks[cname] - if not fk['name']: - fk['name'] = cname - fk['referred_table'] = self.normalize_name(row['targetrname']) - fk['constrained_columns'].append( - self.normalize_name(row['fname'])) - fk['referred_columns'].append( - self.normalize_name(row['targetfname'])) + if not fk["name"]: + fk["name"] = cname + fk["referred_table"] = self.normalize_name(row["targetrname"]) + fk["constrained_columns"].append(self.normalize_name(row["fname"])) + fk["referred_columns"].append( + self.normalize_name(row["targetfname"]) + ) return list(fks.values()) @reflection.cache @@ -729,13 +957,14 @@ class FBDialect(default.DefaultDialect): indexes = util.defaultdict(dict) for row in c: - indexrec = indexes[row['index_name']] - if 'name' not in indexrec: - indexrec['name'] = self.normalize_name(row['index_name']) - indexrec['column_names'] = [] - indexrec['unique'] = bool(row['unique_flag']) - - indexrec['column_names'].append( - self.normalize_name(row['field_name'])) + indexrec = indexes[row["index_name"]] + if "name" not in indexrec: + indexrec["name"] = self.normalize_name(row["index_name"]) + indexrec["column_names"] = [] + indexrec["unique"] = bool(row["unique_flag"]) + + indexrec["column_names"].append( + self.normalize_name(row["field_name"]) + ) return list(indexes.values()) |
