diff options
Diffstat (limited to 'lib/sqlalchemy/dialects/access/base.py')
| -rw-r--r-- | lib/sqlalchemy/dialects/access/base.py | 94 |
1 files changed, 65 insertions, 29 deletions
diff --git a/lib/sqlalchemy/dialects/access/base.py b/lib/sqlalchemy/dialects/access/base.py index 7d9270877..8efea5941 100644 --- a/lib/sqlalchemy/dialects/access/base.py +++ b/lib/sqlalchemy/dialects/access/base.py @@ -100,9 +100,11 @@ class AcTimeStamp(types.TIMESTAMP): class AccessExecutionContext(default.DefaultExecutionContext): def _has_implicit_sequence(self, column): if column.primary_key and column.autoincrement: - if isinstance(column.type, types.Integer) and not column.foreign_keys: - if column.default is None or (isinstance(column.default, schema.Sequence) and \ - column.default.optional): + if isinstance(column.type, types.Integer) and \ + not column.foreign_keys: + if column.default is None or \ + (isinstance(column.default, schema.Sequence) and \ + column.default.optional): return True return False @@ -114,17 +116,20 @@ class AccessExecutionContext(default.DefaultExecutionContext): if not hasattr(tbl, 'has_sequence'): tbl.has_sequence = None for column in tbl.c: - if getattr(column, 'sequence', False) or self._has_implicit_sequence(column): + if getattr(column, 'sequence', False) or \ + self._has_implicit_sequence(column): tbl.has_sequence = column break if bool(tbl.has_sequence): # TBD: for some reason _last_inserted_ids doesn't exist here # (but it does at corresponding point in mssql???) - #if not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None: + #if not len(self._last_inserted_ids) or + # self._last_inserted_ids[0] is None: self.cursor.execute("SELECT @@identity AS lastrowid") row = self.cursor.fetchone() - self._last_inserted_ids = [int(row[0])] #+ self._last_inserted_ids[1:] + self._last_inserted_ids = [int(row[0])] + #+ self._last_inserted_ids[1:] # print "LAST ROW ID", self._last_inserted_ids super(AccessExecutionContext, self).post_exec() @@ -162,6 +167,7 @@ class AccessDialect(default.DefaultDialect): self.text_as_varchar = False self._dtbs = None + @classmethod def dbapi(cls): import win32com.client, pythoncom @@ -170,16 +176,19 @@ class AccessDialect(default.DefaultDialect): const = win32com.client.constants for suffix in (".36", ".35", ".30"): try: - daoEngine = win32com.client.gencache.EnsureDispatch("DAO.DBEngine" + suffix) + daoEngine = win32com.client.\ + gencache.\ + EnsureDispatch("DAO.DBEngine" + suffix) break except pythoncom.com_error: pass else: - raise exc.InvalidRequestError("Can't find a DB engine. Check http://support.microsoft.com/kb/239114 for details.") + raise exc.InvalidRequestError( + "Can't find a DB engine. Check " + "http://support.microsoft.com/kb/239114 for details.") import pyodbc as module return module - dbapi = classmethod(dbapi) def create_connect_args(self, url): opts = url.translate_connect_args() @@ -197,7 +206,8 @@ class AccessDialect(default.DefaultDialect): def do_execute(self, cursor, statement, params, context=None): if params == {}: params = () - super(AccessDialect, self).do_execute(cursor, statement, params, **kwargs) + super(AccessDialect, self).\ + do_execute(cursor, statement, params, **kwargs) def _execute(self, c, statement, parameters): try: @@ -230,7 +240,8 @@ class AccessDialect(default.DefaultDialect): const.dbLongBinary: AcBinary, const.dbMemo: AcText, const.dbBoolean: AcBoolean, - const.dbText: AcUnicode, # All Access strings are unicode + const.dbText: AcUnicode, # All Access strings are + # unicode const.dbCurrency: AcNumeric, } @@ -252,7 +263,8 @@ class AccessDialect(default.DefaultDialect): colargs = \ { - 'nullable': not(col.Required or col.Attributes & const.dbAutoIncrField), + 'nullable': not(col.Required or + col.Attributes & const.dbAutoIncrField), } default = col.DefaultValue @@ -261,9 +273,11 @@ class AccessDialect(default.DefaultDialect): elif default: if col.Type == const.dbBoolean: default = default == 'Yes' and '1' or '0' - colargs['server_default'] = schema.DefaultClause(sql.text(default)) + colargs['server_default'] = \ + schema.DefaultClause(sql.text(default)) - table.append_column(schema.Column(col.Name, coltype, **colargs)) + table.append_column( + schema.Column(col.Name, coltype, **colargs)) # TBD: check constraints @@ -274,7 +288,11 @@ class AccessDialect(default.DefaultDialect): thecol = table.c[col.Name] table.primary_key.add(thecol) if isinstance(thecol.type, AcInteger) and \ - not (thecol.default and isinstance(thecol.default.arg, schema.Sequence)): + not (thecol.default and + isinstance( + thecol.default.arg, + schema.Sequence + )): thecol.autoincrement = False # Then add other indexes @@ -294,7 +312,9 @@ class AccessDialect(default.DefaultDialect): continue scols = [c.ForeignName for c in fk.Fields] rcols = ['%s.%s' % (fk.Table, c.Name) for c in fk.Fields] - table.append_constraint(schema.ForeignKeyConstraint(scols, rcols, link_to_name=True)) + table.append_constraint( + schema.ForeignKeyConstraint(scols, rcols,\ + link_to_name=True)) finally: dtbs.Close() @@ -305,7 +325,8 @@ class AccessDialect(default.DefaultDialect): # This is necessary, so we get the latest updates dtbs = daoEngine.OpenDatabase(connection.engine.url.database) - names = [t.Name for t in dtbs.TableDefs if t.Name[:4] != "MSys" and t.Name[:4] != "~TMP"] + names = [t.Name for t in dtbs.TableDefs + if t.Name[:4] != "MSys" and t.Name[:4] != "~TMP"] dtbs.Close() return names @@ -331,7 +352,8 @@ class AccessCompiler(compiler.SQLCompiler): if select.limit: s += "TOP %s " % (select.limit) if select.offset: - raise exc.InvalidRequestError('Access does not support LIMIT with an offset') + raise exc.InvalidRequestError( + 'Access does not support LIMIT with an offset') return s def limit_clause(self, select): @@ -346,14 +368,16 @@ class AccessCompiler(compiler.SQLCompiler): if isinstance(column, expression.Function): return column.label() else: - return super(AccessCompiler, self).label_select_column(select, column, asfrom) + return super(AccessCompiler, self).\ + label_select_column(select, column, asfrom) function_rewrites = {'current_date': 'now', 'current_timestamp': 'now', 'length': 'len', } def visit_function(self, func): - """Access function names differ from the ANSI SQL names; rewrite common ones""" + """Access function names differ from the ANSI SQL names; + rewrite common ones""" func.name = self.function_rewrites.get(func.name, func.name) return super(AccessCompiler, self).visit_function(func) @@ -369,21 +393,30 @@ class AccessCompiler(compiler.SQLCompiler): return "" def visit_join(self, join, asfrom=False, **kwargs): - return (self.process(join.left, asfrom=True) + (join.isouter and " LEFT OUTER JOIN " or " INNER JOIN ") + \ - self.process(join.right, asfrom=True) + " ON " + self.process(join.onclause)) + return (self.process(join.left, asfrom=True) + \ + (join.isouter and " LEFT OUTER JOIN " or " INNER JOIN ") + \ + self.process(join.right, asfrom=True) + " ON " + \ + self.process(join.onclause)) def visit_extract(self, extract, **kw): field = self.extract_map.get(extract.field, extract.field) - return 'DATEPART("%s", %s)' % (field, self.process(extract.expr, **kw)) + return 'DATEPART("%s", %s)' % \ + (field, self.process(extract.expr, **kw)) class AccessDDLCompiler(compiler.DDLCompiler): def get_column_specification(self, column, **kwargs): - colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec() + colspec = self.preparer.format_column(column) + " " + \ + column.type.dialect_impl(self.dialect).get_col_spec() # install a sequence if we have an implicit IDENTITY column - if (not getattr(column.table, 'has_sequence', False)) and column.primary_key and \ - column.autoincrement and isinstance(column.type, types.Integer) and not column.foreign_keys: - if column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional): + if (not getattr(column.table, 'has_sequence', False)) and \ + column.primary_key and \ + column.autoincrement and \ + isinstance(column.type, types.Integer) and \ + not column.foreign_keys: + if column.default is None or \ + (isinstance(column.default, schema.Sequence) and + column.default.optional): column.sequence = schema.Sequence(column.name + '_seq') if not column.nullable: @@ -401,13 +434,16 @@ class AccessDDLCompiler(compiler.DDLCompiler): def visit_drop_index(self, drop): index = drop.element - self.append("\nDROP INDEX [%s].[%s]" % (index.table.name, self._validate_identifier(index.name, False))) + self.append("\nDROP INDEX [%s].[%s]" % \ + (index.table.name, + self._validate_identifier(index.name, False))) class AccessIdentifierPreparer(compiler.IdentifierPreparer): reserved_words = compiler.RESERVED_WORDS.copy() reserved_words.update(['value', 'text']) def __init__(self, dialect): - super(AccessIdentifierPreparer, self).__init__(dialect, initial_quote='[', final_quote=']') + super(AccessIdentifierPreparer, self).\ + __init__(dialect, initial_quote='[', final_quote=']') dialect = AccessDialect |
