| field | value | date |
|---|---|---|
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2009-08-06 21:11:27 +0000 |
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2009-08-06 21:11:27 +0000 |
| commit | 8fc5005dfe3eb66a46470ad8a8c7b95fc4d6bdca | |
| tree | ae9e27d12c9fbf8297bb90469509e1cb6a206242 /lib/sqlalchemy/dialects | |
| parent | 7638aa7f242c6ea3d743aa9100e32be2052546a6 | |
merge 0.6 series to trunk.
Diffstat (limited to 'lib/sqlalchemy/dialects')
43 files changed, 11179 insertions, 0 deletions
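As a quick orientation before the diff itself (illustrative only, not part of the commit): the enabled entries in the new package's `__all__` are importable directly, and in 0.6 `create_engine` URL schemes resolve against these package names.

```python
# Illustrative only -- not part of this commit. The enabled dialect
# packages added below are importable directly:
from sqlalchemy.dialects import mysql, oracle, postgresql, sqlite

# and create_engine() resolves the URL scheme against these names,
# e.g. "postgresql://" maps to sqlalchemy.dialects.postgresql:
from sqlalchemy import create_engine
engine = create_engine("postgresql://scott:tiger@localhost/test")
```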
diff --git a/lib/sqlalchemy/dialects/__init__.py b/lib/sqlalchemy/dialects/__init__.py
new file mode 100644
index 000000000..91ca91faf
--- /dev/null
+++ b/lib/sqlalchemy/dialects/__init__.py
@@ -0,0 +1,12 @@
+__all__ = (
+#    'access',
+#    'firebird',
+#    'informix',
+#    'maxdb',
+#    'mssql',
+    'mysql',
+    'oracle',
+    'postgresql',
+    'sqlite',
+#    'sybase',
+    )
diff --git a/lib/sqlalchemy/dialects/access/__init__.py b/lib/sqlalchemy/dialects/access/__init__.py
new file mode 100644
index 000000000..e69de29bb
--- /dev/null
+++ b/lib/sqlalchemy/dialects/access/__init__.py
diff --git a/lib/sqlalchemy/dialects/access/base.py b/lib/sqlalchemy/dialects/access/base.py
new file mode 100644
index 000000000..ed8297137
--- /dev/null
+++ b/lib/sqlalchemy/dialects/access/base.py
@@ -0,0 +1,442 @@
+# access.py
+# Copyright (C) 2007 Paul Johnston, paj@pajhome.org.uk
+# Portions derived from jet2sql.py by Matt Keranen, mksql@yahoo.com
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+
+"""
+Support for the Microsoft Access database.
+
+This dialect is *not* tested on SQLAlchemy 0.6.
+
+
+"""
+from sqlalchemy import sql, schema, types, exc, pool
+from sqlalchemy.sql import compiler, expression
+from sqlalchemy.engine import default, base
+
+
+class AcNumeric(types.Numeric):
+    def result_processor(self, dialect):
+        return None
+
+    def bind_processor(self, dialect):
+        def process(value):
+            if value is None:
+                # Not sure that this exception is needed
+                return value
+            else:
+                return str(value)
+        return process
+
+    def get_col_spec(self):
+        return "NUMERIC"
+
+class AcFloat(types.Float):
+    def get_col_spec(self):
+        return "FLOAT"
+
+    def bind_processor(self, dialect):
+        """By converting to string, we can use Decimal types round-trip."""
+        def process(value):
+            if not value is None:
+                return str(value)
+            return None
+        return process
+
+class AcInteger(types.Integer):
+    def get_col_spec(self):
+        return "INTEGER"
+
+class AcTinyInteger(types.Integer):
+    def get_col_spec(self):
+        return "TINYINT"
+
+class AcSmallInteger(types.SmallInteger):
+    def get_col_spec(self):
+        return "SMALLINT"
+
+class AcDateTime(types.DateTime):
+    def __init__(self, *a, **kw):
+        super(AcDateTime, self).__init__(False)
+
+    def get_col_spec(self):
+        return "DATETIME"
+
+class AcDate(types.Date):
+    def __init__(self, *a, **kw):
+        super(AcDate, self).__init__(False)
+
+    def get_col_spec(self):
+        return "DATETIME"
+
+class AcText(types.Text):
+    def get_col_spec(self):
+        return "MEMO"
+
+class AcString(types.String):
+    def get_col_spec(self):
+        return "TEXT" + (self.length and ("(%d)" % self.length) or "")
+
+class AcUnicode(types.Unicode):
+    def get_col_spec(self):
+        return "TEXT" + (self.length and ("(%d)" % self.length) or "")
+
+    def bind_processor(self, dialect):
+        return None
+
+    def result_processor(self, dialect):
+        return None
+
+class AcChar(types.CHAR):
+    def get_col_spec(self):
+        return "TEXT" + (self.length and ("(%d)" % self.length) or "")
+
+class AcBinary(types.Binary):
+    def get_col_spec(self):
+        return "BINARY"
+
+class AcBoolean(types.Boolean):
+    def get_col_spec(self):
+        return "YESNO"
+
+    def result_processor(self, dialect):
+        def process(value):
+            if value is None:
+                return None
+            return value and True or False
+        return process
+
+    def bind_processor(self, dialect):
+        def process(value):
+            if value is True:
+                return 1
+            elif value is False:
+                return 0
+            elif value is None:
+                return None
+            else:
+                return value and True or False
+        return process
+
+class AcTimeStamp(types.TIMESTAMP):
+    def get_col_spec(self):
+        return "TIMESTAMP"
+
+class AccessExecutionContext(default.DefaultExecutionContext):
+    def _has_implicit_sequence(self, column):
+        if column.primary_key and column.autoincrement:
+            if isinstance(column.type, types.Integer) and not column.foreign_keys:
+                if column.default is None or (isinstance(column.default, schema.Sequence) and \
+                                              column.default.optional):
+                    return True
+        return False
+
+    def post_exec(self):
+        """If we inserted into a row with a COUNTER column, fetch the ID"""
+
+        if self.compiled.isinsert:
+            tbl = self.compiled.statement.table
+            if not hasattr(tbl, 'has_sequence'):
+                tbl.has_sequence = None
+                for column in tbl.c:
+                    if getattr(column, 'sequence', False) or self._has_implicit_sequence(column):
+                        tbl.has_sequence = column
+                        break
+
+            if bool(tbl.has_sequence):
+                # TBD: for some reason _last_inserted_ids doesn't exist here
+                # (but it does at corresponding point in mssql???)
+                #if not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None:
+                self.cursor.execute("SELECT @@identity AS lastrowid")
+                row = self.cursor.fetchone()
+                self._last_inserted_ids = [int(row[0])]  #+ self._last_inserted_ids[1:]
+                # print "LAST ROW ID", self._last_inserted_ids
+
+        super(AccessExecutionContext, self).post_exec()
+
+
+const, daoEngine = None, None
+class AccessDialect(default.DefaultDialect):
+    colspecs = {
+        types.Unicode: AcUnicode,
+        types.Integer: AcInteger,
+        types.SmallInteger: AcSmallInteger,
+        types.Numeric: AcNumeric,
+        types.Float: AcFloat,
+        types.DateTime: AcDateTime,
+        types.Date: AcDate,
+        types.String: AcString,
+        types.Binary: AcBinary,
+        types.Boolean: AcBoolean,
+        types.Text: AcText,
+        types.CHAR: AcChar,
+        types.TIMESTAMP: AcTimeStamp,
+    }
+    name = 'access'
+    supports_sane_rowcount = False
+    supports_sane_multi_rowcount = False
+
+    def type_descriptor(self, typeobj):
+        newobj = types.adapt_type(typeobj, self.colspecs)
+        return newobj
+
+    def __init__(self, **params):
+        super(AccessDialect, self).__init__(**params)
+        self.text_as_varchar = False
+        self._dtbs = None
+
+    def dbapi(cls):
+        import win32com.client, pythoncom
+
+        global const, daoEngine
+        if const is None:
+            const = win32com.client.constants
+            for suffix in (".36", ".35", ".30"):
+                try:
+                    daoEngine = win32com.client.gencache.EnsureDispatch("DAO.DBEngine" + suffix)
+                    break
+                except pythoncom.com_error:
+                    pass
+            else:
+                raise exc.InvalidRequestError("Can't find a DB engine. Check http://support.microsoft.com/kb/239114 for details.")
+
+        import pyodbc as module
+        return module
+    dbapi = classmethod(dbapi)
+
+    def create_connect_args(self, url):
+        opts = url.translate_connect_args()
+        connectors = ["Driver={Microsoft Access Driver (*.mdb)}"]
+        connectors.append("Dbq=%s" % opts["database"])
+        user = opts.get("username", None)
+        if user:
+            connectors.append("UID=%s" % user)
+            connectors.append("PWD=%s" % opts.get("password", ""))
+        return [[";".join(connectors)], {}]
+
+    def last_inserted_ids(self):
+        return self.context.last_inserted_ids
+
+    def do_execute(self, cursor, statement, params, **kwargs):
+        if params == {}:
+            params = ()
+        super(AccessDialect, self).do_execute(cursor, statement, params, **kwargs)
+
+    def _execute(self, c, statement, parameters):
+        try:
+            if parameters == {}:
+                parameters = ()
+            c.execute(statement, parameters)
+            self.context.rowcount = c.rowcount
+        except Exception, e:
+            raise exc.DBAPIError.instance(statement, parameters, e)
+
+    def has_table(self, connection, tablename, schema=None):
+        # This approach seems to be more reliable than using DAO
+        try:
+            connection.execute('select top 1 * from [%s]' % tablename)
+            return True
+        except Exception, e:
+            return False
+
+    def reflecttable(self, connection, table, include_columns):
+        # This is defined in the function, as it relies on win32com constants
+        # that aren't imported until the dbapi method is called
+        if not hasattr(self, 'ischema_names'):
+            self.ischema_names = {
+                const.dbByte: AcBinary,
+                const.dbInteger: AcInteger,
+                const.dbLong: AcInteger,
+                const.dbSingle: AcFloat,
+                const.dbDouble: AcFloat,
+                const.dbDate: AcDateTime,
+                const.dbLongBinary: AcBinary,
+                const.dbMemo: AcText,
+                const.dbBoolean: AcBoolean,
+                const.dbText: AcUnicode,  # All Access strings are unicode
+                const.dbCurrency: AcNumeric,
+            }
+
+        # A fresh DAO connection is opened for each reflection.
+        # This is necessary, so we get the latest updates
+        dtbs = daoEngine.OpenDatabase(connection.engine.url.database)
+
+        try:
+            for tbl in dtbs.TableDefs:
+                if tbl.Name.lower() == table.name.lower():
+                    break
+            else:
+                raise exc.NoSuchTableError(table.name)
+
+            for col in tbl.Fields:
+                coltype = self.ischema_names[col.Type]
+                if col.Type == const.dbText:
+                    coltype = coltype(col.Size)
+
+                colargs = \
+                {
+                    'nullable': not(col.Required or col.Attributes & const.dbAutoIncrField),
+                }
+                default = col.DefaultValue
+
+                if col.Attributes & const.dbAutoIncrField:
+                    colargs['default'] = schema.Sequence(col.Name + '_seq')
+                elif default:
+                    if col.Type == const.dbBoolean:
+                        default = default == 'Yes' and '1' or '0'
+                    colargs['server_default'] = schema.DefaultClause(sql.text(default))
+
+                table.append_column(schema.Column(col.Name, coltype, **colargs))
+
+            # TBD: check constraints
+
+            # Find primary key columns first
+            for idx in tbl.Indexes:
+                if idx.Primary:
+                    for col in idx.Fields:
+                        thecol = table.c[col.Name]
+                        table.primary_key.add(thecol)
+                        if isinstance(thecol.type, AcInteger) and \
+                                not (thecol.default and isinstance(thecol.default.arg, schema.Sequence)):
+                            thecol.autoincrement = False
+
+            # Then add other indexes
+            for idx in tbl.Indexes:
+                if not idx.Primary:
+                    if len(idx.Fields) == 1:
+                        col = table.c[idx.Fields[0].Name]
+                        if not col.primary_key:
+                            col.index = True
+                            col.unique = idx.Unique
+                    else:
+                        pass  # TBD: multi-column indexes
+
+            for fk in dtbs.Relations:
+                if fk.ForeignTable != table.name:
+                    continue
+                scols = [c.ForeignName for c in fk.Fields]
+                rcols = ['%s.%s' % (fk.Table, c.Name) for c in fk.Fields]
+                table.append_constraint(schema.ForeignKeyConstraint(scols, rcols, link_to_name=True))
+
+        finally:
+            dtbs.Close()
+
+    def table_names(self, connection, schema):
+        # A fresh DAO connection is opened for each reflection.
+        # This is necessary, so we get the latest updates
+        dtbs = daoEngine.OpenDatabase(connection.engine.url.database)
+
+        names = [t.Name for t in dtbs.TableDefs if t.Name[:4] != "MSys" and t.Name[:4] != "~TMP"]
+        dtbs.Close()
+        return names
+
+
+class AccessCompiler(compiler.SQLCompiler):
+    extract_map = compiler.SQLCompiler.extract_map.copy()
+    extract_map.update({
+        'month': 'm',
+        'day': 'd',
+        'year': 'yyyy',
+        'second': 's',
+        'hour': 'h',
+        'doy': 'y',
+        'minute': 'n',
+        'quarter': 'q',
+        'dow': 'w',
+        'week': 'ww'
+    })
+
+    def visit_select_precolumns(self, select):
+        """Access puts TOP, its version of LIMIT, here"""
+        s = select.distinct and "DISTINCT " or ""
+        if select.limit:
+            s += "TOP %s " % (select.limit)
+        if select.offset:
+            raise exc.InvalidRequestError('Access does not support LIMIT with an offset')
+        return s
+
+    def limit_clause(self, select):
+        """Limit in access is after the select keyword"""
+        return ""
+
+    def binary_operator_string(self, binary):
+        """Access uses "mod" instead of "%" """
+        return binary.operator == '%' and 'mod' or binary.operator
+
+    def label_select_column(self, select, column, asfrom):
+        if isinstance(column, expression.Function):
+            return column.label()
+        else:
+            return super(AccessCompiler, self).label_select_column(select, column, asfrom)
+
+    function_rewrites = {'current_date': 'now',
+                         'current_timestamp': 'now',
+                         'length': 'len',
+                         }
+    def visit_function(self, func):
+        """Access function names differ from the ANSI SQL names; rewrite common ones"""
+        func.name = self.function_rewrites.get(func.name, func.name)
+        return super(AccessCompiler, self).visit_function(func)
+
+    def for_update_clause(self, select):
+        """FOR UPDATE is not supported by Access; silently ignore"""
+        return ''
+
+    # Strip schema
+    def visit_table(self, table, asfrom=False, **kwargs):
+        if asfrom:
+            return self.preparer.quote(table.name, table.quote)
+        else:
+            return ""
+
+    def visit_join(self, join, asfrom=False, **kwargs):
+        return (self.process(join.left, asfrom=True) + \
+                (join.isouter and " LEFT OUTER JOIN " or " INNER JOIN ") + \
+                self.process(join.right, asfrom=True) + " ON " + \
+                self.process(join.onclause))
+
+    def visit_extract(self, extract):
+        field = self.extract_map.get(extract.field, extract.field)
+        return 'DATEPART("%s", %s)' % (field, self.process(extract.expr))
+
+class AccessDDLCompiler(compiler.DDLCompiler):
+    def get_column_specification(self, column, **kwargs):
+        colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec()
+
+        # install a sequence if we have an implicit IDENTITY column
+        if (not getattr(column.table, 'has_sequence', False)) and column.primary_key and \
+                column.autoincrement and isinstance(column.type, types.Integer) and not column.foreign_keys:
+            if column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional):
+                column.sequence = schema.Sequence(column.name + '_seq')
+
+        if not column.nullable:
+            colspec += " NOT NULL"
+
+        if hasattr(column, 'sequence'):
+            column.table.has_sequence = column
+            colspec = self.preparer.format_column(column) + " counter"
+        else:
+            default = self.get_column_default_string(column)
+            if default is not None:
+                colspec += " DEFAULT " + default
+
+        return colspec
+
+    def visit_drop_index(self, drop):
+        index = drop.element
+        self.append("\nDROP INDEX [%s].[%s]" % (index.table.name, self._validate_identifier(index.name, False)))
+
+class AccessIdentifierPreparer(compiler.IdentifierPreparer):
+    reserved_words = compiler.RESERVED_WORDS.copy()
+    reserved_words.update(['value', 'text'])
+    def __init__(self, dialect):
+        super(AccessIdentifierPreparer, self).__init__(dialect, initial_quote='[', final_quote=']')
+
+
+dialect = AccessDialect
+dialect.poolclass = pool.SingletonThreadPool
+dialect.statement_compiler = AccessCompiler
+dialect.ddlcompiler = AccessDDLCompiler
+dialect.preparer = AccessIdentifierPreparer
+dialect.execution_ctx_cls = AccessExecutionContext
\ No newline at end of file
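Since `access/base.py` imports only from SQLAlchemy itself (win32com and pyodbc are deferred into `dbapi()`), its `create_connect_args` can be traced standalone. A minimal sketch, not part of the commit, assuming the module instantiates cleanly on a non-Windows machine:

```python
# Minimal sketch (not part of the commit): tracing
# AccessDialect.create_connect_args into the ODBC connection string.
from sqlalchemy.engine import url as sa_url
from sqlalchemy.dialects.access.base import AccessDialect

u = sa_url.make_url("access:///northwind.mdb")
print(AccessDialect().create_connect_args(u))
# [['Driver={Microsoft Access Driver (*.mdb)};Dbq=northwind.mdb'], {}]
```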
diff --git a/lib/sqlalchemy/dialects/firebird/__init__.py b/lib/sqlalchemy/dialects/firebird/__init__.py
new file mode 100644
index 000000000..6b1b80db2
--- /dev/null
+++ b/lib/sqlalchemy/dialects/firebird/__init__.py
@@ -0,0 +1,3 @@
+from sqlalchemy.dialects.firebird import base, kinterbasdb
+
+base.dialect = kinterbasdb.dialect
\ No newline at end of file
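The two-line `__init__.py` above is the registration pattern every dialect package in this commit follows: the DBAPI-specific module's `dialect` class becomes the package default, so a URL that names no explicit driver resolves to it. An illustrative sketch, not part of the commit:

```python
# Illustrative, not part of the commit: with base.dialect pointed at
# the kinterbasdb dialect, a driverless firebird:// URL selects it
# (assumes the kinterbasdb DBAPI is actually installed).
from sqlalchemy import create_engine

engine = create_engine("firebird://sysdba:masterkey@localhost/test.fdb")
print(engine.dialect.driver)  # 'kinterbasdb'
```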
diff --git a/lib/sqlalchemy/dialects/firebird/base.py b/lib/sqlalchemy/dialects/firebird/base.py
new file mode 100644
index 000000000..57b89ed05
--- /dev/null
+++ b/lib/sqlalchemy/dialects/firebird/base.py
@@ -0,0 +1,626 @@
+# firebird.py
+# Copyright (C) 2005, 2006, 2007, 2008, 2009 Michael Bayer mike_mp@zzzcomputing.com
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+
+"""
+Support for the Firebird database.
+
+Connectivity is usually supplied via the kinterbasdb_
+DBAPI module.
+
+Firebird dialects
+-----------------
+
+Firebird offers two distinct dialects_ (not to be confused with a
+SQLAlchemy ``Dialect``):
+
+dialect 1
+  This is the old syntax and behaviour, inherited from Interbase pre-6.0.
+
+dialect 3
+  This is the newer and supported syntax, introduced in Interbase 6.0.
+
+The SQLAlchemy Firebird dialect detects these versions and
+adjusts its representation of SQL accordingly. However,
+support for dialect 1 is not well tested and probably has
+incompatibilities.
+
+Firebird Locking Behavior
+-------------------------
+
+Firebird locks tables aggressively. For this reason, a DROP TABLE may
+hang until other transactions are released. SQLAlchemy does its best
+to release transactions as quickly as possible. The most common cause
+of hanging transactions is a non-fully consumed result set, i.e.::
+
+    result = engine.execute("select * from table")
+    row = result.fetchone()
+    return
+
+Where above, the ``ResultProxy`` has not been fully consumed. The
+connection will be returned to the pool and the transactional state
+rolled back once the Python garbage collector reclaims the objects
+which hold onto the connection, which often occurs asynchronously.
+The above use case can be alleviated by calling ``first()`` on the
+``ResultProxy`` which will fetch the first row and immediately close
+all remaining cursor/connection resources.
+
+RETURNING support
+-----------------
+
+Firebird 2.0 supports returning a result set from inserts, and 2.1 extends
+that to deletes and updates.
+
+To use this, pass the column/expression list to the ``firebird_returning``
+parameter when creating the queries::
+
+    raises = tbl.update(empl.c.sales > 100, values=dict(salary=empl.c.salary * 1.1),
+                        firebird_returning=[empl.c.id, empl.c.salary]).execute().fetchall()
+
+
+.. [#] Well, that is not the whole story, as the client may still ask
+   a different (lower) dialect...
+
+.. _dialects: http://mc-computing.com/Databases/Firebird/SQL_Dialect.html
+.. _kinterbasdb: http://sourceforge.net/projects/kinterbasdb
+
+"""
+
+
+import datetime, decimal, re
+
+from sqlalchemy import schema as sa_schema
+from sqlalchemy import exc, types as sqltypes, sql, util
+from sqlalchemy.sql import expression
+from sqlalchemy.engine import base, default, reflection
+from sqlalchemy.sql import compiler
+
+from sqlalchemy.types import (BIGINT, BLOB, BOOLEAN, CHAR, DATE,
+                              FLOAT, INTEGER, NUMERIC, SMALLINT,
+                              TEXT, TIME, TIMESTAMP, VARCHAR)
+
+
+RESERVED_WORDS = set(
+    ["action", "active", "add", "admin", "after", "all", "alter", "and", "any",
+     "as", "asc", "ascending", "at", "auto", "autoddl", "avg", "based", "basename",
+     "base_name", "before", "begin", "between", "bigint", "blob", "blobedit", "buffer",
+     "by", "cache", "cascade", "case", "cast", "char", "character", "character_length",
+     "char_length", "check", "check_point_len", "check_point_length", "close", "collate",
+     "collation", "column", "commit", "committed", "compiletime", "computed", "conditional",
+     "connect", "constraint", "containing", "continue", "count", "create", "cstring",
+     "current", "current_connection", "current_date", "current_role", "current_time",
+     "current_timestamp", "current_transaction", "current_user", "cursor", "database",
+     "date", "day", "db_key", "debug", "dec", "decimal", "declare", "default", "delete",
+     "desc", "descending", "describe", "descriptor", "disconnect", "display", "distinct",
+     "do", "domain", "double", "drop", "echo", "edit", "else", "end", "entry_point",
+     "escape", "event", "exception", "execute", "exists", "exit", "extern", "external",
+     "extract", "fetch", "file", "filter", "float", "for", "foreign", "found", "free_it",
+     "from", "full", "function", "gdscode", "generator", "gen_id", "global", "goto",
+     "grant", "group", "group_commit_", "group_commit_wait", "having", "help", "hour",
+     "if", "immediate", "in", "inactive", "index", "indicator", "init", "inner", "input",
+     "input_type", "insert", "int", "integer", "into", "is", "isolation", "isql", "join",
+     "key", "lc_messages", "lc_type", "left", "length", "lev", "level", "like", "logfile",
+     "log_buffer_size", "log_buf_size", "long", "manual", "max", "maximum", "maximum_segment",
+     "max_segment", "merge", "message", "min", "minimum", "minute", "module_name", "month",
+     "names", "national", "natural", "nchar", "no", "noauto", "not", "null", "numeric",
+     "num_log_buffers", "num_log_bufs", "octet_length", "of", "on", "only", "open", "option",
+     "or", "order", "outer", "output", "output_type", "overflow", "page", "pagelength",
+     "pages", "page_size", "parameter", "password", "plan", "position", "post_event",
+     "precision", "prepare", "primary", "privileges", "procedure", "protected", "public",
+     "quit", "raw_partitions", "rdb$db_key", "read", "real", "record_version", "recreate",
+     "references", "release", "release", "reserv", "reserving", "restrict", "retain",
+     "return", "returning_values", "returns", "revoke", "right", "role", "rollback",
+     "row_count", "runtime", "savepoint", "schema", "second", "segment", "select",
+     "set", "shadow", "shared", "shell", "show", "singular", "size", "smallint",
+     "snapshot", "some", "sort", "sqlcode", "sqlerror", "sqlwarning", "stability",
+     "starting", "starts", "statement", "static", "statistics", "sub_type", "sum",
+     "suspend", "table", "terminator", "then", "time", "timestamp", "to", "transaction",
+     "translate", "translation", "trigger", "trim", "type", "uncommitted", "union",
+     "unique", "update", "upper", "user", "using", "value", "values", "varchar",
+     "variable", "varying", "version", "view", "wait", "wait_time", "weekday", "when",
+     "whenever", "where", "while", "with", "work", "write", "year", "yearday"])
+
+
+class _FBBoolean(sqltypes.Boolean):
+    def result_processor(self, dialect):
+        def process(value):
+            if value is None:
+                return None
+            return value and True or False
+        return process
+
+    def bind_processor(self, dialect):
+        def process(value):
+            if value is True:
+                return 1
+            elif value is False:
+                return 0
+            elif value is None:
+                return None
+            else:
+                return value and True or False
+        return process
+
+
+colspecs = {
+    sqltypes.Boolean: _FBBoolean,
+}
+
+ischema_names = {
+    'SHORT': SMALLINT,
+    'LONG': BIGINT,
+    'QUAD': FLOAT,
+    'FLOAT': FLOAT,
+    'DATE': DATE,
+    'TIME': TIME,
+    'TEXT': TEXT,
+    'INT64': NUMERIC,
+    'DOUBLE': FLOAT,
+    'TIMESTAMP': TIMESTAMP,
+    'VARYING': VARCHAR,
+    'CSTRING': CHAR,
+    'BLOB': BLOB,
+    }
+
+
+# TODO: date conversion types (should be implemented as _FBDateTime, _FBDate, etc.
+# as bind/result functionality is required)
+
+class FBTypeCompiler(compiler.GenericTypeCompiler):
+    def visit_boolean(self, type_):
+        return self.visit_SMALLINT(type_)
+
+    def visit_datetime(self, type_):
+        return self.visit_TIMESTAMP(type_)
+
+    def visit_TEXT(self, type_):
+        return "BLOB SUB_TYPE 1"
+
+    def visit_BLOB(self, type_):
+        return "BLOB SUB_TYPE 0"
+
+
+class FBCompiler(sql.compiler.SQLCompiler):
+    """Firebird specific idiosyncrasies"""
+
+    def visit_mod(self, binary, **kw):
+        # Firebird lacks a builtin modulo operator, but there is
+        # an equivalent function in the ib_udf library.
+        return "mod(%s, %s)" % (self.process(binary.left), self.process(binary.right))
+
+    def visit_alias(self, alias, asfrom=False, **kwargs):
+        if self.dialect._version_two:
+            return super(FBCompiler, self).visit_alias(alias, asfrom=asfrom, **kwargs)
+        else:
+            # Override to not use the AS keyword which FB 1.5 does not like
+            if asfrom:
+                alias_name = isinstance(alias.name, expression._generated_label) and \
+                             self._truncated_identifier("alias", alias.name) or alias.name
+
+                return self.process(alias.original, asfrom=asfrom, **kwargs) + " " + \
+                       self.preparer.format_alias(alias, alias_name)
+            else:
+                return self.process(alias.original, **kwargs)
+
+    def visit_substring_func(self, func, **kw):
+        s = self.process(func.clauses.clauses[0])
+        start = self.process(func.clauses.clauses[1])
+        if len(func.clauses.clauses) > 2:
+            length = self.process(func.clauses.clauses[2])
+            return "SUBSTRING(%s FROM %s FOR %s)" % (s, start, length)
+        else:
+            return "SUBSTRING(%s FROM %s)" % (s, start)
+
+    def visit_length_func(self, function, **kw):
+        if self.dialect._version_two:
+            return "char_length" + self.function_argspec(function)
+        else:
+            return "strlen" + self.function_argspec(function)
+
+    visit_char_length_func = visit_length_func
+
+    def function_argspec(self, func, **kw):
+        if func.clauses:
+            return self.process(func.clause_expr)
+        else:
+            return ""
+
+    def default_from(self):
+        return " FROM rdb$database"
+
+    def visit_sequence(self, seq):
+        return "gen_id(%s, 1)" % self.preparer.format_sequence(seq)
+
+    def get_select_precolumns(self, select):
+        """Called when building a ``SELECT`` statement, position is just
+        before the column list. Firebird puts the limit and offset right
+        after the ``SELECT``...
+        """
+
+        result = ""
+        if select._limit:
+            result += "FIRST %d " % select._limit
+        if select._offset:
+            result += "SKIP %d " % select._offset
+        if select._distinct:
+            result += "DISTINCT "
+        return result
+
+    def limit_clause(self, select):
+        """Already taken care of in the `get_select_precolumns` method."""
+
+        return ""
+
+    def returning_clause(self, stmt, returning_cols):
+
+        columns = [
+            self.process(
+                self.label_select_column(None, c, asfrom=False),
+                within_columns_clause=True,
+                result_map=self.result_map
+            )
+            for c in expression._select_iterables(returning_cols)
+        ]
+        return 'RETURNING ' + ', '.join(columns)
+
+
+class FBDDLCompiler(sql.compiler.DDLCompiler):
+    """Firebird syntactic idiosyncrasies"""
+
+    def visit_create_sequence(self, create):
+        """Generate a ``CREATE GENERATOR`` statement for the sequence."""
+
+        if self.dialect._version_two:
+            return "CREATE SEQUENCE %s" % self.preparer.format_sequence(create.element)
+        else:
+            return "CREATE GENERATOR %s" % self.preparer.format_sequence(create.element)
+
+    def visit_drop_sequence(self, drop):
+        """Generate a ``DROP GENERATOR`` statement for the sequence."""
+
+        if self.dialect._version_two:
+            return "DROP SEQUENCE %s" % self.preparer.format_sequence(drop.element)
+        else:
+            return "DROP GENERATOR %s" % self.preparer.format_sequence(drop.element)
+
+
+class FBDefaultRunner(base.DefaultRunner):
+    """Firebird specific idiosyncrasies"""
+
+    def visit_sequence(self, seq):
+        """Get the next value from the sequence using ``gen_id()``."""
+
+        return self.execute_string("SELECT gen_id(%s, 1) FROM rdb$database" % \
+            self.dialect.identifier_preparer.format_sequence(seq))
+
+
+class FBIdentifierPreparer(sql.compiler.IdentifierPreparer):
+    """Install Firebird specific reserved words."""
+
+    reserved_words = RESERVED_WORDS
+
+    def __init__(self, dialect):
+        super(FBIdentifierPreparer, self).__init__(dialect, omit_schema=True)
+
+
+class FBDialect(default.DefaultDialect):
+    """Firebird dialect"""
+
+    name = 'firebird'
+
+    max_identifier_length = 31
+
+    supports_sequences = True
+    sequences_optional = False
+    supports_default_values = True
+    postfetch_lastrowid = False
+
+    requires_name_normalize = True
+    supports_empty_insert = False
+
+    statement_compiler = FBCompiler
+    ddl_compiler = FBDDLCompiler
+    defaultrunner = FBDefaultRunner
+    preparer = FBIdentifierPreparer
+    type_compiler = FBTypeCompiler
+
+    colspecs = colspecs
+    ischema_names = ischema_names
+
+    # defaults to dialect ver. 3,
+    # will be autodetected upon
+    # first connect
+    _version_two = True
+
+    def initialize(self, connection):
+        super(FBDialect, self).initialize(connection)
+        self._version_two = self.server_version_info > (2, )
+        if not self._version_two:
+            # TODO: whatever other pre < 2.0 stuff goes here
+            self.ischema_names = ischema_names.copy()
+            self.ischema_names['TIMESTAMP'] = sqltypes.DATE
+            self.colspecs = {
+                sqltypes.DateTime: sqltypes.DATE
+            }
+        else:
+            self.implicit_returning = True
+
+    def normalize_name(self, name):
+        # Remove trailing spaces: FB uses a CHAR() type,
+        # that is padded with spaces
+        name = name and name.rstrip()
+        if name is None:
+            return None
+        elif name.upper() == name and \
+                not self.identifier_preparer._requires_quotes(name.lower()):
+            return name.lower()
+        else:
+            return name
+
+    def denormalize_name(self, name):
+        if name is None:
+            return None
+        elif name.lower() == name and \
+                not self.identifier_preparer._requires_quotes(name.lower()):
+            return name.upper()
+        else:
+            return name
+
+    def has_table(self, connection, table_name, schema=None):
+        """Return ``True`` if the given table exists, ignoring the `schema`."""
+
+        tblqry = """
+        SELECT 1 FROM rdb$database
+        WHERE EXISTS (SELECT rdb$relation_name
+                      FROM rdb$relations
+                      WHERE rdb$relation_name=?)
+        """
+        c = connection.execute(tblqry, [self.denormalize_name(table_name)])
+        return c.first() is not None
+
+    def has_sequence(self, connection, sequence_name):
+        """Return ``True`` if the given sequence (generator) exists."""
+
+        genqry = """
+        SELECT 1 FROM rdb$database
+        WHERE EXISTS (SELECT rdb$generator_name
+                      FROM rdb$generators
+                      WHERE rdb$generator_name=?)
+        """
+        c = connection.execute(genqry, [self.denormalize_name(sequence_name)])
+        return c.first() is not None
+
+    def table_names(self, connection, schema):
+        s = """
+        SELECT DISTINCT rdb$relation_name
+        FROM rdb$relation_fields
+        WHERE rdb$system_flag=0 AND rdb$view_context IS NULL
+        """
+        return [self.normalize_name(row[0]) for row in connection.execute(s)]
+
+    @reflection.cache
+    def get_table_names(self, connection, schema=None, **kw):
+        return self.table_names(connection, schema)
+
+    @reflection.cache
+    def get_view_names(self, connection, schema=None, **kw):
+        s = """
+        SELECT distinct rdb$view_name
+        FROM rdb$view_relations
+        """
+        return [self.normalize_name(row[0]) for row in connection.execute(s)]
+
+    @reflection.cache
+    def get_view_definition(self, connection, view_name, schema=None, **kw):
+        qry = """
+        SELECT rdb$view_source AS view_source
+        FROM rdb$relations
+        WHERE rdb$relation_name=?
+        """
+        rp = connection.execute(qry, [self.denormalize_name(view_name)])
+        row = rp.first()
+        if row:
+            return row['view_source']
+        else:
+            return None
+
+    @reflection.cache
+    def get_primary_keys(self, connection, table_name, schema=None, **kw):
+        # Query to extract the PK/FK constrained fields of the given table
+        keyqry = """
+        SELECT se.rdb$field_name AS fname
+        FROM rdb$relation_constraints rc
+             JOIN rdb$index_segments se ON rc.rdb$index_name=se.rdb$index_name
+        WHERE rc.rdb$constraint_type=? AND rc.rdb$relation_name=?
+        """
+        tablename = self.denormalize_name(table_name)
+        # get primary key fields
+        c = connection.execute(keyqry, ["PRIMARY KEY", tablename])
+        pkfields = [self.normalize_name(r['fname']) for r in c.fetchall()]
+        return pkfields
+
+    @reflection.cache
+    def get_column_sequence(self, connection, table_name, column_name, schema=None, **kw):
+        tablename = self.denormalize_name(table_name)
+        colname = self.denormalize_name(column_name)
+        # Heuristic-query to determine the generator associated to a PK field
+        genqry = """
+        SELECT trigdep.rdb$depended_on_name AS fgenerator
+        FROM rdb$dependencies tabdep
+             JOIN rdb$dependencies trigdep
+                  ON tabdep.rdb$dependent_name=trigdep.rdb$dependent_name
+                     AND trigdep.rdb$depended_on_type=14
+                     AND trigdep.rdb$dependent_type=2
+             JOIN rdb$triggers trig ON trig.rdb$trigger_name=tabdep.rdb$dependent_name
+        WHERE tabdep.rdb$depended_on_name=?
+              AND tabdep.rdb$depended_on_type=0
+              AND trig.rdb$trigger_type=1
+              AND tabdep.rdb$field_name=?
+              AND (SELECT count(*)
+                   FROM rdb$dependencies trigdep2
+                   WHERE trigdep2.rdb$dependent_name = trigdep.rdb$dependent_name) = 2
+        """
+        genc = connection.execute(genqry, [tablename, colname])
+        genr = genc.fetchone()
+        if genr is not None:
+            return dict(name=self.normalize_name(genr['fgenerator']))
+
+    @reflection.cache
+    def get_columns(self, connection, table_name, schema=None, **kw):
+        # Query to extract the details of all the fields of the given table
+        tblqry = """
+        SELECT DISTINCT r.rdb$field_name AS fname,
+                        r.rdb$null_flag AS null_flag,
+                        t.rdb$type_name AS ftype,
+                        f.rdb$field_sub_type AS stype,
+                        f.rdb$field_length AS flen,
+                        f.rdb$field_precision AS fprec,
+                        f.rdb$field_scale AS fscale,
+                        COALESCE(r.rdb$default_source, f.rdb$default_source) AS fdefault
+        FROM rdb$relation_fields r
+             JOIN rdb$fields f ON r.rdb$field_source=f.rdb$field_name
+             JOIN rdb$types t
+                  ON t.rdb$type=f.rdb$field_type AND t.rdb$field_name='RDB$FIELD_TYPE'
+        WHERE f.rdb$system_flag=0 AND r.rdb$relation_name=?
+        ORDER BY r.rdb$field_position
+        """
+        # get the PK, used to determine the eventual associated sequence
+        pkey_cols = self.get_primary_keys(connection, table_name)
+
+        tablename = self.denormalize_name(table_name)
+        # get all of the fields for this table
+        c = connection.execute(tblqry, [tablename])
+        cols = []
+        while True:
+            row = c.fetchone()
+            if row is None:
+                break
+            name = self.normalize_name(row['fname'])
+            # get the data type
+
+            colspec = row['ftype'].rstrip()
+            coltype = self.ischema_names.get(colspec)
+            if coltype is None:
+                util.warn("Did not recognize type '%s' of column '%s'" %
+                          (colspec, name))
+                coltype = sqltypes.NULLTYPE
+            elif colspec == 'INT64':
+                coltype = coltype(precision=row['fprec'], scale=row['fscale'] * -1)
+            elif colspec in ('VARYING', 'CSTRING'):
+                coltype = coltype(row['flen'])
+            elif colspec == 'TEXT':
+                coltype = TEXT(row['flen'])
+            elif colspec == 'BLOB':
+                if row['stype'] == 1:
+                    coltype = TEXT()
+                else:
+                    coltype = BLOB()
+            else:
+                coltype = coltype(row)
+
+            # does it have a default value?
+            defvalue = None
+            if row['fdefault'] is not None:
+                # the value comes down as "DEFAULT 'value'"
+                assert row['fdefault'].upper().startswith('DEFAULT '), row
+                defvalue = row['fdefault'][8:]
+            col_d = {
+                'name': name,
+                'type': coltype,
+                'nullable': not bool(row['null_flag']),
+                'default': defvalue
+            }
+
+            # if the PK is a single field, try to see if it's linked to
+            # a sequence through a trigger
+            if len(pkey_cols) == 1 and name == pkey_cols[0]:
+                seq_d = self.get_column_sequence(connection, tablename, name)
+                if seq_d is not None:
+                    col_d['sequence'] = seq_d
+
+            cols.append(col_d)
+        return cols
+
+    @reflection.cache
+    def get_foreign_keys(self, connection, table_name, schema=None, **kw):
+        # Query to extract the details of each UK/FK of the given table
+        fkqry = """
+        SELECT rc.rdb$constraint_name AS cname,
+               cse.rdb$field_name AS fname,
+               ix2.rdb$relation_name AS targetrname,
+               se.rdb$field_name AS targetfname
+        FROM rdb$relation_constraints rc
+             JOIN rdb$indices ix1 ON ix1.rdb$index_name=rc.rdb$index_name
+             JOIN rdb$indices ix2 ON ix2.rdb$index_name=ix1.rdb$foreign_key
+             JOIN rdb$index_segments cse ON cse.rdb$index_name=ix1.rdb$index_name
+             JOIN rdb$index_segments se
+                  ON se.rdb$index_name=ix2.rdb$index_name
+                     AND se.rdb$field_position=cse.rdb$field_position
+        WHERE rc.rdb$constraint_type=? AND rc.rdb$relation_name=?
+        ORDER BY se.rdb$index_name, se.rdb$field_position
+        """
+        tablename = self.denormalize_name(table_name)
+
+        c = connection.execute(fkqry, ["FOREIGN KEY", tablename])
+        fks = util.defaultdict(lambda: {
+            'name': None,
+            'constrained_columns': [],
+            'referred_schema': None,
+            'referred_table': None,
+            'referred_columns': []
+        })
+
+        for row in c:
+            cname = self.normalize_name(row['cname'])
+            fk = fks[cname]
+            if not fk['name']:
+                fk['name'] = cname
+                fk['referred_table'] = self.normalize_name(row['targetrname'])
+            fk['constrained_columns'].append(self.normalize_name(row['fname']))
+            fk['referred_columns'].append(
+                self.normalize_name(row['targetfname']))
+        return fks.values()
+
+    @reflection.cache
+    def get_indexes(self, connection, table_name, schema=None, **kw):
+        qry = """
+        SELECT ix.rdb$index_name AS index_name,
+               ix.rdb$unique_flag AS unique_flag,
+               ic.rdb$field_name AS field_name
+        FROM rdb$indices ix
+             JOIN rdb$index_segments ic
+                  ON ix.rdb$index_name=ic.rdb$index_name
+             LEFT OUTER JOIN rdb$relation_constraints
+                  ON rdb$relation_constraints.rdb$index_name = ic.rdb$index_name
+        WHERE ix.rdb$relation_name=? AND ix.rdb$foreign_key IS NULL
+              AND rdb$relation_constraints.rdb$constraint_type IS NULL
+        ORDER BY index_name, field_name
+        """
+        c = connection.execute(qry, [self.denormalize_name(table_name)])
+
+        indexes = util.defaultdict(dict)
+        for row in c:
+            indexrec = indexes[row['index_name']]
+            if 'name' not in indexrec:
+                indexrec['name'] = self.normalize_name(row['index_name'])
+                indexrec['column_names'] = []
+                indexrec['unique'] = bool(row['unique_flag'])
+
+            indexrec['column_names'].append(self.normalize_name(row['field_name']))
+
+        return indexes.values()
+
+    def do_execute(self, cursor, statement, parameters, **kwargs):
+        # kinterbasdb does not accept a None, but wants an empty list
+        # when there are no arguments.
+        cursor.execute(statement, parameters or [])
+
+    def do_rollback(self, connection):
+        # Use the retaining feature, that keeps the transaction going
+        connection.rollback(True)
+
+    def do_commit(self, connection):
+        # Use the retaining feature, that keeps the transaction going
+        connection.commit(True)
diff --git a/lib/sqlalchemy/dialects/firebird/kinterbasdb.py b/lib/sqlalchemy/dialects/firebird/kinterbasdb.py
new file mode 100644
index 000000000..7d30f87f5
--- /dev/null
+++ b/lib/sqlalchemy/dialects/firebird/kinterbasdb.py
@@ -0,0 +1,70 @@
+from sqlalchemy.dialects.firebird.base import FBDialect, FBCompiler
+from sqlalchemy.engine.default import DefaultExecutionContext
+
+_initialized_kb = False
+
+class Firebird_kinterbasdb(FBDialect):
+    driver = 'kinterbasdb'
+    supports_sane_rowcount = False
+    supports_sane_multi_rowcount = False
+
+    def __init__(self, type_conv=200, concurrency_level=1, **kwargs):
+        super(Firebird_kinterbasdb, self).__init__(**kwargs)
+
+        self.type_conv = type_conv
+        self.concurrency_level = concurrency_level
+
+    @classmethod
+    def dbapi(cls):
+        k = __import__('kinterbasdb')
+        return k
+
+    def create_connect_args(self, url):
+        opts = url.translate_connect_args(username='user')
+        if opts.get('port'):
+            opts['host'] = "%s/%s" % (opts['host'], opts['port'])
+            del opts['port']
+        opts.update(url.query)
+
+        type_conv = opts.pop('type_conv', self.type_conv)
+        concurrency_level = opts.pop('concurrency_level', self.concurrency_level)
+        global _initialized_kb
+        if not _initialized_kb and self.dbapi is not None:
+            _initialized_kb = True
+            self.dbapi.init(type_conv=type_conv, concurrency_level=concurrency_level)
+        return ([], opts)
+
+    def _get_server_version_info(self, connection):
+        """Get the version of the Firebird server used by a connection.
+
+        Returns a tuple of (`major`, `minor`, `build`), three integers
+        representing the version of the attached server.
+        """
+
+        # This is the simpler approach (the other uses the services api),
+        # that for backward compatibility reasons returns a string like
+        #   LI-V6.3.3.12981 Firebird 2.0
+        # where the first version is a fake one resembling the old
+        # Interbase signature. This is more than enough for our purposes,
+        # as this is mainly (only?) used by the testsuite.
+
+        from re import match
+
+        fbconn = connection.connection
+        version = fbconn.server_version
+        m = match('\w+-V(\d+)\.(\d+)\.(\d+)\.(\d+) \w+ (\d+)\.(\d+)', version)
+        if not m:
+            raise AssertionError("Could not determine version from string '%s'" % version)
+        return tuple([int(x) for x in m.group(5, 6, 4)])
+
+    def is_disconnect(self, e):
+        if isinstance(e, self.dbapi.OperationalError):
+            return 'Unable to complete network request to host' in str(e)
+        elif isinstance(e, self.dbapi.ProgrammingError):
+            msg = str(e)
+            return ('Invalid connection state' in msg or
+                    'Invalid cursor state' in msg)
+        else:
+            return False
+
+dialect = Firebird_kinterbasdb
diff --git a/lib/sqlalchemy/dialects/informix/__init__.py b/lib/sqlalchemy/dialects/informix/__init__.py
new file mode 100644
index 000000000..f2fcc76d4
--- /dev/null
+++ b/lib/sqlalchemy/dialects/informix/__init__.py
@@ -0,0 +1,3 @@
+from sqlalchemy.dialects.informix import base, informixdb
+
+base.dialect = informixdb.dialect
\ No newline at end of file
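One detail of `firebird/kinterbasdb.py` above that is easy to misread: `_get_server_version_info` returns the *Firebird* major/minor plus the build number from the Interbase-style prefix, via `m.group(5, 6, 4)`. A worked example, not part of the commit:

```python
# Worked example of the version parsing in
# Firebird_kinterbasdb._get_server_version_info above.
import re

version = "LI-V6.3.3.12981 Firebird 2.0"
m = re.match(r'\w+-V(\d+)\.(\d+)\.(\d+)\.(\d+) \w+ (\d+)\.(\d+)', version)
print(tuple(int(x) for x in m.group(5, 6, 4)))  # (2, 0, 12981)
```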
diff --git a/lib/sqlalchemy/dialects/informix/base.py b/lib/sqlalchemy/dialects/informix/base.py
new file mode 100644
index 000000000..b69748fcf
--- /dev/null
+++ b/lib/sqlalchemy/dialects/informix/base.py
@@ -0,0 +1,334 @@
+# informix.py
+# Copyright (C) 2005,2006, 2007, 2008, 2009 Michael Bayer mike_mp@zzzcomputing.com
+#
+# coding: gbk
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+"""Support for the Informix database.
+
+This dialect is *not* tested on SQLAlchemy 0.6.
+
+"""
+
+
+import datetime
+
+from sqlalchemy import sql, schema, exc, pool, util
+from sqlalchemy.sql import compiler
+from sqlalchemy.engine import default
+from sqlalchemy import types as sqltypes
+
+
+class InfoDateTime(sqltypes.DateTime):
+    def bind_processor(self, dialect):
+        def process(value):
+            if value is not None:
+                if value.microsecond:
+                    value = value.replace(microsecond=0)
+            return value
+        return process
+
+class InfoTime(sqltypes.Time):
+    def bind_processor(self, dialect):
+        def process(value):
+            if value is not None:
+                if value.microsecond:
+                    value = value.replace(microsecond=0)
+            return value
+        return process
+
+    def result_processor(self, dialect):
+        def process(value):
+            if isinstance(value, datetime.datetime):
+                return value.time()
+            else:
+                return value
+        return process
+
+
+class InfoBoolean(sqltypes.Boolean):
+    def result_processor(self, dialect):
+        def process(value):
+            if value is None:
+                return None
+            return value and True or False
+        return process
+
+    def bind_processor(self, dialect):
+        def process(value):
+            if value is True:
+                return 1
+            elif value is False:
+                return 0
+            elif value is None:
+                return None
+            else:
+                return value and True or False
+        return process
+
+colspecs = {
+    sqltypes.DateTime: InfoDateTime,
+    sqltypes.Time: InfoTime,
+    sqltypes.Boolean: InfoBoolean,
+}
+
+
+ischema_names = {
+    0: sqltypes.CHAR,       # CHAR
+    1: sqltypes.SMALLINT,   # SMALLINT
+    2: sqltypes.INTEGER,    # INT
+    3: sqltypes.FLOAT,      # Float
+    3: sqltypes.Float,      # SmallFloat
+    5: sqltypes.DECIMAL,    # DECIMAL
+    6: sqltypes.Integer,    # Serial
+    7: sqltypes.DATE,       # DATE
+    8: sqltypes.Numeric,    # MONEY
+    10: sqltypes.DATETIME,  # DATETIME
+    11: sqltypes.Binary,    # BYTE
+    12: sqltypes.TEXT,      # TEXT
+    13: sqltypes.VARCHAR,   # VARCHAR
+    15: sqltypes.NCHAR,     # NCHAR
+    16: sqltypes.NVARCHAR,  # NVARCHAR
+    17: sqltypes.Integer,   # INT8
+    18: sqltypes.Integer,   # Serial8
+    43: sqltypes.String,    # LVARCHAR
+    -1: sqltypes.BLOB,      # BLOB
+    -1: sqltypes.CLOB,      # CLOB
+}
+
+
+class InfoTypeCompiler(compiler.GenericTypeCompiler):
+    def visit_DATETIME(self, type_):
+        return "DATETIME YEAR TO SECOND"
+
+    def visit_TIME(self, type_):
+        return "DATETIME HOUR TO SECOND"
+
+    def visit_binary(self, type_):
+        return "BYTE"
+
+    def visit_boolean(self, type_):
+        return "SMALLINT"
+
+class InfoSQLCompiler(compiler.SQLCompiler):
+
+    def __init__(self, *args, **kwargs):
+        self.limit = 0
+        self.offset = 0
+
+        compiler.SQLCompiler.__init__(self, *args, **kwargs)
+
+    def default_from(self):
+        return " from systables where tabname = 'systables' "
+
+    def get_select_precolumns(self, select):
+        s = select._distinct and "DISTINCT " or ""
+        # only has limit
+        if select._limit:
+            off = select._offset or 0
+            s += " FIRST %s " % (select._limit + off)
+        else:
+            s += ""
+        return s
+
+    def visit_select(self, select):
+        if select._offset:
+            self.offset = select._offset
+            self.limit = select._limit or 0
+        # the columns in the order by clause must appear in the select too
+
+        def __label(c):
+            try:
+                return c._label.lower()
+            except:
+                return ''
+
+        # TODO: don't modify the original select, generate a new one
+        a = [__label(c) for c in select._raw_columns]
+        for c in select._order_by_clause.clauses:
+            if __label(c) not in a:
+                select.append_column(c)
+
+        return compiler.SQLCompiler.visit_select(self, select)
+
+    def limit_clause(self, select):
+        return ""
+
+    def visit_function(self, func):
+        if func.name.lower() == 'current_date':
+            return "today"
+        elif func.name.lower() == 'current_time':
+            return "CURRENT HOUR TO SECOND"
+        elif func.name.lower() in ('current_timestamp', 'now'):
+            return "CURRENT YEAR TO SECOND"
+        else:
+            return compiler.SQLCompiler.visit_function(self, func)
+
+    def visit_clauselist(self, list, **kwargs):
+        return ', '.join([s for s in [self.process(c) for c in list.clauses] if s is not None])
+
+class InfoDDLCompiler(compiler.DDLCompiler):
+    def get_column_specification(self, column, first_pk=False):
+        colspec = self.preparer.format_column(column)
+        if column.primary_key and len(column.foreign_keys) == 0 and column.autoincrement and \
+                isinstance(column.type, sqltypes.Integer) and not getattr(self, 'has_serial', False) and first_pk:
+            colspec += " SERIAL"
+            self.has_serial = True
+        else:
+            colspec += " " + self.dialect.type_compiler.process(column.type)
+        default = self.get_column_default_string(column)
+        if default is not None:
+            colspec += " DEFAULT " + default
+
+        if not column.nullable:
+            colspec += " NOT NULL"
+
+        return colspec
+
+    def post_create_table(self, table):
+        if hasattr(self, 'has_serial'):
+            del self.has_serial
+        return ''
+
+class InfoIdentifierPreparer(compiler.IdentifierPreparer):
+    def __init__(self, dialect):
+        super(InfoIdentifierPreparer, self).__init__(dialect, initial_quote="'")
+
+    def format_constraint(self, constraint):
+        # informix doesn't support names for constraints
+        return ''
+
+    def _requires_quotes(self, value):
+        return False
+
+class InformixDialect(default.DefaultDialect):
+    name = 'informix'
+    # for informix 7.31
+    max_identifier_length = 18
+    type_compiler = InfoTypeCompiler
+    poolclass = pool.SingletonThreadPool
+    statement_compiler = InfoSQLCompiler
+    ddl_compiler = InfoDDLCompiler
+    preparer = InfoIdentifierPreparer
+    colspecs = colspecs
+    ischema_names = ischema_names
+
+    def do_begin(self, connect):
+        cu = connect.cursor()
+        cu.execute('SET LOCK MODE TO WAIT')
+        #cu.execute('SET ISOLATION TO REPEATABLE READ')
+
+    def table_names(self, connection, schema):
+        s = "select tabname from systables"
+        return [row[0] for row in connection.execute(s)]
+
+    def has_table(self, connection, table_name, schema=None):
+        cursor = connection.execute("""select tabname from systables where tabname=?""", table_name.lower())
+        return bool(cursor.fetchone() is not None)
+
+    def reflecttable(self, connection, table, include_columns):
+        c = connection.execute("select distinct OWNER from systables where tabname=?", table.name.lower())
+        rows = c.fetchall()
+        if not rows:
+            raise exc.NoSuchTableError(table.name)
+        else:
+            if table.owner is not None:
+                if table.owner.lower() in [r[0] for r in rows]:
+                    owner = table.owner.lower()
+                else:
+                    raise AssertionError("Specified owner %s does not own table %s" % (table.owner, table.name))
+            else:
+                if len(rows) == 1:
+                    owner = rows[0][0]
+                else:
+                    raise AssertionError("There are multiple tables with name %s in the schema, you must specify an owner" % table.name)
+
+        c = connection.execute("""select colname, coltype, collength, t3.default, t1.colno
+                                  from syscolumns as t1, systables as t2, OUTER sysdefaults as t3
+                                  where t1.tabid = t2.tabid and t2.tabname=? and t2.owner=?
+                                  and t3.tabid = t2.tabid and t3.colno = t1.colno
+                                  order by t1.colno""", table.name.lower(), owner)
+        rows = c.fetchall()
+
+        if not rows:
+            raise exc.NoSuchTableError(table.name)
+
+        for name, colattr, collength, default, colno in rows:
+            name = name.lower()
+            if include_columns and name not in include_columns:
+                continue
+
+            # in 7.31, coltype = 0x000
+            #                      ^^-- column type
+            #                     ^-- 1 not null, 0 null
+            nullable, coltype = divmod(colattr, 256)
+            if coltype not in (0, 13) and default:
+                default = default.split()[-1]
+
+            if coltype == 0 or coltype == 13:  # char, varchar
+                coltype = ischema_names.get(coltype, InfoString)(collength)
+                if default:
+                    default = "'%s'" % default
+            elif coltype == 5:  # decimal
+                precision, scale = (collength & 0xFF00) >> 8, collength & 0xFF
+                if scale == 255:
+                    scale = 0
+                coltype = InfoNumeric(precision, scale)
+            else:
+                try:
+                    coltype = ischema_names[coltype]
+                except KeyError:
+                    util.warn("Did not recognize type '%s' of column '%s'" %
+                              (coltype, name))
+                    coltype = sqltypes.NULLTYPE
+
+            colargs = []
+            if default is not None:
+                colargs.append(schema.DefaultClause(sql.text(default)))
+
+            table.append_column(schema.Column(name, coltype, nullable=(nullable == 0), *colargs))
+
+        # FK
+        c = connection.execute("""select t1.constrname as cons_name, t1.constrtype as cons_type,
+                                  t4.colname as local_column, t7.tabname as remote_table,
+                                  t6.colname as remote_column
+                                  from sysconstraints as t1, systables as t2,
+                                  sysindexes as t3, syscolumns as t4,
+                                  sysreferences as t5, syscolumns as t6, systables as t7,
+                                  sysconstraints as t8, sysindexes as t9
+                                  where t1.tabid = t2.tabid and t2.tabname=? and t2.owner=? and t1.constrtype = 'R'
+                                  and t3.tabid = t2.tabid and t3.idxname = t1.idxname
+                                  and t4.tabid = t2.tabid and t4.colno = t3.part1
+                                  and t5.constrid = t1.constrid and t8.constrid = t5.primary
+                                  and t6.tabid = t5.ptabid and t6.colno = t9.part1 and t9.idxname = t8.idxname
+                                  and t7.tabid = t5.ptabid""", table.name.lower(), owner)
+        rows = c.fetchall()
+        fks = {}
+        for cons_name, cons_type, local_column, remote_table, remote_column in rows:
+            try:
+                fk = fks[cons_name]
+            except KeyError:
+                fk = ([], [])
+                fks[cons_name] = fk
+            refspec = ".".join([remote_table, remote_column])
+            schema.Table(remote_table, table.metadata, autoload=True, autoload_with=connection)
+            if local_column not in fk[0]:
+                fk[0].append(local_column)
+            if refspec not in fk[1]:
+                fk[1].append(refspec)
+
+        for name, value in fks.iteritems():
+            table.append_constraint(schema.ForeignKeyConstraint(value[0], value[1], None, link_to_name=True))
+
+        # PK
+        c = connection.execute("""select t1.constrname as cons_name, t1.constrtype as cons_type,
+                                  t4.colname as local_column
+                                  from sysconstraints as t1, systables as t2,
+                                  sysindexes as t3, syscolumns as t4
+                                  where t1.tabid = t2.tabid and t2.tabname=? and t2.owner=? and t1.constrtype = 'P'
+                                  and t3.tabid = t2.tabid and t3.idxname = t1.idxname
+                                  and t4.tabid = t2.tabid and t4.colno = t3.part1""", table.name.lower(), owner)
+        rows = c.fetchall()
+        for cons_name, cons_type, local_column in rows:
+            table.primary_key.add(table.c[local_column])
+
diff --git a/lib/sqlalchemy/dialects/informix/informixdb.py b/lib/sqlalchemy/dialects/informix/informixdb.py
new file mode 100644
index 000000000..4e929e024
--- /dev/null
+++ b/lib/sqlalchemy/dialects/informix/informixdb.py
@@ -0,0 +1,79 @@
+from sqlalchemy.dialects.informix.base import InformixDialect
+from sqlalchemy.engine import default
+
+# for offset
+
+class informix_cursor(object):
+    def __init__(self, con):
+        self.__cursor = con.cursor()
+        self.rowcount = 0
+
+    def offset(self, n):
+        if n > 0:
+            self.fetchmany(n)
+            self.rowcount = self.__cursor.rowcount - n
+            if self.rowcount < 0:
+                self.rowcount = 0
+        else:
+            self.rowcount = self.__cursor.rowcount
+
+    def execute(self, sql, params):
+        if params is None or len(params) == 0:
+            params = []
+
+        return self.__cursor.execute(sql, params)
+
+    def __getattr__(self, name):
+        if name not in ('offset', '__cursor', 'rowcount', '__del__', 'execute'):
+            return getattr(self.__cursor, name)
+
+
+class InfoExecutionContext(default.DefaultExecutionContext):
+    # cursor.sqlerrd
+    #     0 - estimated number of rows returned
+    #     1 - serial value after insert or ISAM error code
+    #     2 - number of rows processed
+    #     3 - estimated cost
+    #     4 - offset of the error into the SQL statement
+    #     5 - rowid after insert
+    def post_exec(self):
+        if getattr(self.compiled, "isinsert", False) and self.inserted_primary_key is None:
+            self._last_inserted_ids = [self.cursor.sqlerrd[1]]
+        elif hasattr(self.compiled, 'offset'):
+            self.cursor.offset(self.compiled.offset)
+
+    def create_cursor(self):
+        return informix_cursor(self.connection.connection)
+
+
+class Informix_informixdb(InformixDialect):
+    driver = 'informixdb'
+    default_paramstyle = 'qmark'
+    execution_context_cls = InfoExecutionContext
+
+    @classmethod
+    def dbapi(cls):
+        return __import__('informixdb')
+
+    def create_connect_args(self, url):
+        if url.host:
+            dsn = '%s@%s' % (url.database, url.host)
+        else:
+            dsn = url.database
+
+        if url.username:
+            opt = {'user': url.username, 'password': url.password}
+        else:
+            opt = {}
+
+        return ([dsn], opt)
+
+    def is_disconnect(self, e):
+        if isinstance(e, self.dbapi.OperationalError):
+            return 'closed the connection' in str(e) or 'connection not open' in str(e)
+        else:
+            return False
+
+
+dialect = Informix_informixdb
\ No newline at end of file
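`Informix_informixdb.create_connect_args` above builds a `database@host` DSN when the URL carries a host. A quick trace, illustrative and not part of the commit:

```python
# Illustrative trace of Informix_informixdb.create_connect_args above.
from sqlalchemy.engine import url as sa_url
from sqlalchemy.dialects.informix.informixdb import Informix_informixdb

u = sa_url.make_url("informix://scott:tiger@myhost/mydb")
print(Informix_informixdb().create_connect_args(u))
# (['mydb@myhost'], {'user': 'scott', 'password': 'tiger'})
```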
\ No newline at end of file diff --git a/lib/sqlalchemy/dialects/maxdb/__init__.py b/lib/sqlalchemy/dialects/maxdb/__init__.py new file mode 100644 index 000000000..3f12448b7 --- /dev/null +++ b/lib/sqlalchemy/dialects/maxdb/__init__.py @@ -0,0 +1,3 @@ +from sqlalchemy.dialects.maxdb import base, sapdb + +base.dialect = sapdb.dialect
\ No newline at end of file diff --git a/lib/sqlalchemy/dialects/maxdb/base.py b/lib/sqlalchemy/dialects/maxdb/base.py new file mode 100644 index 000000000..1ec95e03b --- /dev/null +++ b/lib/sqlalchemy/dialects/maxdb/base.py @@ -0,0 +1,1054 @@ +# maxdb.py +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Support for the MaxDB database. + +This dialect is *not* tested on SQLAlchemy 0.6. + +Overview +-------- + +The ``maxdb`` dialect is **experimental** and has only been tested on 7.6.03.007 +and 7.6.00.037. Of these, **only 7.6.03.007 will work** with SQLAlchemy's ORM. +The earlier version has severe ``LEFT JOIN`` limitations and will return +incorrect results from even very simple ORM queries. + +Only the native Python DB-API is currently supported. ODBC driver support +is a future enhancement. + +Connecting +---------- + +The username is case-sensitive. If you usually connect to the +database with sqlcli and other tools in lower case, you likely need to +use upper case for DB-API. + +Implementation Notes +-------------------- + +Also check the DatabaseNotes page on the wiki for detailed information. + +With the 7.6.00.37 driver and Python 2.5, it seems that all DB-API +generated exceptions are broken and can cause Python to crash. + +For 'somecol.in_([])' to work, the IN operator's generation must be changed +to cast 'NULL' to a numeric, i.e. NUM(NULL). The DB-API doesn't accept a +bind parameter there, so that particular generation must inline the NULL value, +which depends on [ticket:807]. + +The DB-API is very picky about where bind params may be used in queries. + +Bind params for some functions (e.g. MOD) need type information supplied. +The dialect does not yet do this automatically. + +Max will occasionally throw up 'bad sql, compile again' exceptions for +perfectly valid SQL. The dialect does not currently handle these, more +research is needed. + +MaxDB 7.5 and Sap DB <= 7.4 reportedly do not support schemas. A very +slightly different version of this dialect would be required to support +those versions, and can easily be added if there is demand. Some other +required components such as an Max-aware 'old oracle style' join compiler +(thetas with (+) outer indicators) are already done and available for +integration- email the devel list if you're interested in working on +this. 
+ +""" +import datetime, itertools, re + +from sqlalchemy import exc, schema, sql, util +from sqlalchemy.sql import operators as sql_operators, expression as sql_expr +from sqlalchemy.sql import compiler, visitors +from sqlalchemy.engine import base as engine_base, default +from sqlalchemy import types as sqltypes + + +class _StringType(sqltypes.String): + _type = None + + def __init__(self, length=None, encoding=None, **kw): + super(_StringType, self).__init__(length=length, **kw) + self.encoding = encoding + + def bind_processor(self, dialect): + if self.encoding == 'unicode': + return None + else: + def process(value): + if isinstance(value, unicode): + return value.encode(dialect.encoding) + else: + return value + return process + + def result_processor(self, dialect): + def process(value): + while True: + if value is None: + return None + elif isinstance(value, unicode): + return value + elif isinstance(value, str): + if self.convert_unicode or dialect.convert_unicode: + return value.decode(dialect.encoding) + else: + return value + elif hasattr(value, 'read'): + # some sort of LONG, snarf and retry + value = value.read(value.remainingLength()) + continue + else: + # unexpected type, return as-is + return value + return process + + +class MaxString(_StringType): + _type = 'VARCHAR' + + def __init__(self, *a, **kw): + super(MaxString, self).__init__(*a, **kw) + + +class MaxUnicode(_StringType): + _type = 'VARCHAR' + + def __init__(self, length=None, **kw): + super(MaxUnicode, self).__init__(length=length, encoding='unicode') + + +class MaxChar(_StringType): + _type = 'CHAR' + + +class MaxText(_StringType): + _type = 'LONG' + + def __init__(self, *a, **kw): + super(MaxText, self).__init__(*a, **kw) + + def get_col_spec(self): + spec = 'LONG' + if self.encoding is not None: + spec = ' '.join((spec, self.encoding)) + elif self.convert_unicode: + spec = ' '.join((spec, 'UNICODE')) + + return spec + + +class MaxNumeric(sqltypes.Numeric): + """The FIXED (also NUMERIC, DECIMAL) data type.""" + + def __init__(self, precision=None, scale=None, **kw): + kw.setdefault('asdecimal', True) + super(MaxNumeric, self).__init__(scale=scale, precision=precision, + **kw) + + def bind_processor(self, dialect): + return None + +class MaxTimestamp(sqltypes.DateTime): + def bind_processor(self, dialect): + def process(value): + if value is None: + return None + elif isinstance(value, basestring): + return value + elif dialect.datetimeformat == 'internal': + ms = getattr(value, 'microsecond', 0) + return value.strftime("%Y%m%d%H%M%S" + ("%06u" % ms)) + elif dialect.datetimeformat == 'iso': + ms = getattr(value, 'microsecond', 0) + return value.strftime("%Y-%m-%d %H:%M:%S." + ("%06u" % ms)) + else: + raise exc.InvalidRequestError( + "datetimeformat '%s' is not supported." % ( + dialect.datetimeformat,)) + return process + + def result_processor(self, dialect): + def process(value): + if value is None: + return None + elif dialect.datetimeformat == 'internal': + return datetime.datetime( + *[int(v) + for v in (value[0:4], value[4:6], value[6:8], + value[8:10], value[10:12], value[12:14], + value[14:])]) + elif dialect.datetimeformat == 'iso': + return datetime.datetime( + *[int(v) + for v in (value[0:4], value[5:7], value[8:10], + value[11:13], value[14:16], value[17:19], + value[20:])]) + else: + raise exc.InvalidRequestError( + "datetimeformat '%s' is not supported." 
% ( + dialect.datetimeformat,)) + return process + + +class MaxDate(sqltypes.Date): + def bind_processor(self, dialect): + def process(value): + if value is None: + return None + elif isinstance(value, basestring): + return value + elif dialect.datetimeformat == 'internal': + return value.strftime("%Y%m%d") + elif dialect.datetimeformat == 'iso': + return value.strftime("%Y-%m-%d") + else: + raise exc.InvalidRequestError( + "datetimeformat '%s' is not supported." % ( + dialect.datetimeformat,)) + return process + + def result_processor(self, dialect): + def process(value): + if value is None: + return None + elif dialect.datetimeformat == 'internal': + return datetime.date( + *[int(v) for v in (value[0:4], value[4:6], value[6:8])]) + elif dialect.datetimeformat == 'iso': + return datetime.date( + *[int(v) for v in (value[0:4], value[5:7], value[8:10])]) + else: + raise exc.InvalidRequestError( + "datetimeformat '%s' is not supported." % ( + dialect.datetimeformat,)) + return process + + +class MaxTime(sqltypes.Time): + def bind_processor(self, dialect): + def process(value): + if value is None: + return None + elif isinstance(value, basestring): + return value + elif dialect.datetimeformat == 'internal': + return value.strftime("%H%M%S") + elif dialect.datetimeformat == 'iso': + return value.strftime("%H-%M-%S") + else: + raise exc.InvalidRequestError( + "datetimeformat '%s' is not supported." % ( + dialect.datetimeformat,)) + return process + + def result_processor(self, dialect): + def process(value): + if value is None: + return None + elif dialect.datetimeformat == 'internal': + t = datetime.time( + *[int(v) for v in (value[0:4], value[4:6], value[6:8])]) + return t + elif dialect.datetimeformat == 'iso': + return datetime.time( + *[int(v) for v in (value[0:4], value[5:7], value[8:10])]) + else: + raise exc.InvalidRequestError( + "datetimeformat '%s' is not supported." 
% ( + dialect.datetimeformat,)) + return process + + +class MaxBlob(sqltypes.Binary): + def bind_processor(self, dialect): + def process(value): + if value is None: + return None + else: + return str(value) + return process + + def result_processor(self, dialect): + def process(value): + if value is None: + return None + else: + return value.read(value.remainingLength()) + return process + +class MaxDBTypeCompiler(compiler.GenericTypeCompiler): + def _string_spec(self, string_spec, type_): + if type_.length is None: + spec = 'LONG' + else: + spec = '%s(%s)' % (string_spec, type_.length) + + if getattr(type_, 'encoding', None): + spec = ' '.join([spec, getattr(type_, 'encoding').upper()]) + return spec + + def visit_text(self, type_): + spec = 'LONG' + if getattr(type_, 'encoding', None): + spec = ' '.join((spec, type_.encoding)) + elif type_.convert_unicode: + spec = ' '.join((spec, 'UNICODE')) + + return spec + + def visit_char(self, type_): + return self._string_spec("CHAR", type_) + + def visit_string(self, type_): + return self._string_spec("VARCHAR", type_) + + def visit_binary(self, type_): + return "LONG BYTE" + + def visit_numeric(self, type_): + if type_.scale and type_.precision: + return 'FIXED(%s, %s)' % (type_.precision, type_.scale) + elif type_.precision: + return 'FIXED(%s)' % type_.precision + else: + return 'INTEGER' + + def visit_BOOLEAN(self, type_): + return "BOOLEAN" + +colspecs = { + sqltypes.Numeric: MaxNumeric, + sqltypes.DateTime: MaxTimestamp, + sqltypes.Date: MaxDate, + sqltypes.Time: MaxTime, + sqltypes.String: MaxString, + sqltypes.Unicode: MaxUnicode, + sqltypes.Binary: MaxBlob, + sqltypes.Text: MaxText, + sqltypes.CHAR: MaxChar, + sqltypes.TIMESTAMP: MaxTimestamp, + sqltypes.BLOB: MaxBlob, + } + +ischema_names = { + 'boolean': sqltypes.BOOLEAN, + 'char': sqltypes.CHAR, + 'character': sqltypes.CHAR, + 'date': sqltypes.DATE, + 'fixed': sqltypes.Numeric, + 'float': sqltypes.FLOAT, + 'int': sqltypes.INT, + 'integer': sqltypes.INT, + 'long binary': sqltypes.BLOB, + 'long unicode': sqltypes.Text, + 'long': sqltypes.Text, + 'smallint': sqltypes.SmallInteger, + 'time': sqltypes.Time, + 'timestamp': sqltypes.TIMESTAMP, + 'varchar': sqltypes.VARCHAR, + } + +# TODO: migrate this to sapdb.py +class MaxDBExecutionContext(default.DefaultExecutionContext): + def post_exec(self): + # DB-API bug: if there were any functions as values, + # then do another select and pull CURRVAL from the + # autoincrement column's implicit sequence... ugh + if self.compiled.isinsert and not self.executemany: + table = self.compiled.statement.table + index, serial_col = _autoserial_column(table) + + if serial_col and (not self.compiled._safeserial or + not(self._last_inserted_ids) or + self._last_inserted_ids[index] in (None, 0)): + if table.schema: + sql = "SELECT %s.CURRVAL FROM DUAL" % ( + self.compiled.preparer.format_table(table)) + else: + sql = "SELECT CURRENT_SCHEMA.%s.CURRVAL FROM DUAL" % ( + self.compiled.preparer.format_table(table)) + + if self.connection.engine._should_log_info: + self.connection.engine.logger.info(sql) + rs = self.cursor.execute(sql) + id = rs.fetchone()[0] + + if self.connection.engine._should_log_debug: + self.connection.engine.logger.debug([id]) + if not self._last_inserted_ids: + # This shouldn't ever be > 1? Right?
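+                    # (The code assumes at most one DEFAULT SERIAL column
+                    # per table: a single CURRVAL fetch covers the insert,
+                    # and the list is padded to primary-key width so the
+                    # value can be slotted in by index.)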
+ self._last_inserted_ids = \ + [None] * len(table.primary_key.columns) + self._last_inserted_ids[index] = id + + super(MaxDBExecutionContext, self).post_exec() + + def get_result_proxy(self): + if self.cursor.description is not None: + for column in self.cursor.description: + if column[1] in ('Long Binary', 'Long', 'Long Unicode'): + return MaxDBResultProxy(self) + return engine_base.ResultProxy(self) + + @property + def rowcount(self): + if hasattr(self, '_rowcount'): + return self._rowcount + else: + return self.cursor.rowcount + +class MaxDBCachedColumnRow(engine_base.RowProxy): + """A RowProxy that only runs result_processors once per column.""" + + def __init__(self, parent, row): + super(MaxDBCachedColumnRow, self).__init__(parent, row) + self.columns = {} + self._row = row + self._parent = parent + + def _get_col(self, key): + if key not in self.columns: + self.columns[key] = self._parent._get_col(self._row, key) + return self.columns[key] + + def __iter__(self): + for i in xrange(len(self._row)): + yield self._get_col(i) + + def __repr__(self): + return repr(list(self)) + + def __eq__(self, other): + return ((other is self) or + (other == tuple([self._get_col(key) + for key in xrange(len(self._row))]))) + def __getitem__(self, key): + if isinstance(key, slice): + indices = key.indices(len(self._row)) + return tuple([self._get_col(i) for i in xrange(*indices)]) + else: + return self._get_col(key) + + def __getattr__(self, name): + try: + return self._get_col(name) + except KeyError: + raise AttributeError(name) + + +class MaxDBResultProxy(engine_base.ResultProxy): + _process_row = MaxDBCachedColumnRow + +class MaxDBCompiler(compiler.SQLCompiler): + + function_conversion = { + 'CURRENT_DATE': 'DATE', + 'CURRENT_TIME': 'TIME', + 'CURRENT_TIMESTAMP': 'TIMESTAMP', + } + + # These functions must be written without parens when called with no + # parameters. e.g. 'SELECT DATE FROM DUAL' not 'SELECT DATE() FROM DUAL' + bare_functions = set([ + 'CURRENT_SCHEMA', 'DATE', 'FALSE', 'SYSDBA', 'TIME', 'TIMESTAMP', + 'TIMEZONE', 'TRANSACTION', 'TRUE', 'USER', 'UID', 'USERGROUP', + 'UTCDATE', 'UTCDIFF']) + + def visit_mod(self, binary, **kw): + return "mod(%s, %s)" % (self.process(binary.left), self.process(binary.right)) + + def default_from(self): + return ' FROM DUAL' + + def for_update_clause(self, select): + clause = select.for_update + if clause is True: + return " WITH LOCK EXCLUSIVE" + elif clause is None: + return "" + elif clause == "read": + return " WITH LOCK" + elif clause == "ignore": + return " WITH LOCK (IGNORE) EXCLUSIVE" + elif clause == "nowait": + return " WITH LOCK (NOWAIT) EXCLUSIVE" + elif isinstance(clause, basestring): + return " WITH LOCK %s" % clause.upper() + elif not clause: + return "" + else: + return " WITH LOCK EXCLUSIVE" + + def function_argspec(self, fn, **kw): + if fn.name.upper() in self.bare_functions: + return "" + elif len(fn.clauses) > 0: + return compiler.SQLCompiler.function_argspec(self, fn, **kw) + else: + return "" + + def visit_function(self, fn, **kw): + transform = self.function_conversion.get(fn.name.upper(), None) + if transform: + fn = fn._clone() + fn.name = transform + return super(MaxDBCompiler, self).visit_function(fn, **kw) + + def visit_cast(self, cast, **kwargs): + # MaxDB only supports casts * to NUMERIC, * to VARCHAR or + # date/time to VARCHAR. Casts of LONGs will fail. 
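+        # Concretely: CAST(x AS INTEGER/NUMERIC) renders as NUM(x),
+        # CAST(x AS VARCHAR) renders as CHR(x), and any other target
+        # type drops the CAST and emits the bare expression.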
+ if isinstance(cast.type, (sqltypes.Integer, sqltypes.Numeric)): + return "NUM(%s)" % self.process(cast.clause) + elif isinstance(cast.type, sqltypes.String): + return "CHR(%s)" % self.process(cast.clause) + else: + return self.process(cast.clause) + + def visit_sequence(self, sequence): + if sequence.optional: + return None + else: + return (self.dialect.identifier_preparer.format_sequence(sequence) + + ".NEXTVAL") + + class ColumnSnagger(visitors.ClauseVisitor): + def __init__(self): + self.count = 0 + self.column = None + def visit_column(self, column): + self.column = column + self.count += 1 + + def _find_labeled_columns(self, columns, use_labels=False): + labels = {} + for column in columns: + if isinstance(column, basestring): + continue + snagger = self.ColumnSnagger() + snagger.traverse(column) + if snagger.count == 1: + if isinstance(column, sql_expr._Label): + labels[unicode(snagger.column)] = column.name + elif use_labels: + labels[unicode(snagger.column)] = column._label + + return labels + + def order_by_clause(self, select): + order_by = self.process(select._order_by_clause) + + # ORDER BY clauses in DISTINCT queries must reference aliased + # inner columns by alias name, not true column name. + if order_by and getattr(select, '_distinct', False): + labels = self._find_labeled_columns(select.inner_columns, + select.use_labels) + if labels: + for needs_alias in labels.keys(): + r = re.compile(r'(^| )(%s)(,| |$)' % + re.escape(needs_alias)) + order_by = r.sub((r'\1%s\3' % labels[needs_alias]), + order_by) + + # No ORDER BY in subqueries. + if order_by: + if self.is_subquery(): + # It's safe to simply drop the ORDER BY if there is no + # LIMIT. Right? Other dialects seem to get away with + # dropping order. + if select._limit: + raise exc.InvalidRequestError( + "MaxDB does not support ORDER BY in subqueries") + else: + return "" + return " ORDER BY " + order_by + else: + return "" + + def get_select_precolumns(self, select): + # Convert a subquery's LIMIT to TOP + sql = select._distinct and 'DISTINCT ' or '' + if self.is_subquery() and select._limit: + if select._offset: + raise exc.InvalidRequestError( + 'MaxDB does not support LIMIT with an offset.') + sql += 'TOP %s ' % select._limit + return sql + + def limit_clause(self, select): + # The docs say offsets are supported with LIMIT. But they're not. + # TODO: maybe emulate by adding a ROWNO/ROWNUM predicate? 
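+        # (Subquery limits are instead rendered as TOP by
+        # get_select_precolumns above.)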
+ if self.is_subquery(): + # sub queries need TOP + return '' + elif select._offset: + raise exc.InvalidRequestError( + 'MaxDB does not support LIMIT with an offset.') + else: + return ' \n LIMIT %s' % (select._limit,) + + def visit_insert(self, insert): + self.isinsert = True + self._safeserial = True + + colparams = self._get_colparams(insert) + for value in (insert.parameters or {}).itervalues(): + if isinstance(value, sql_expr.Function): + self._safeserial = False + break + + return ''.join(('INSERT INTO ', + self.preparer.format_table(insert.table), + ' (', + ', '.join([self.preparer.format_column(c[0]) + for c in colparams]), + ') VALUES (', + ', '.join([c[1] for c in colparams]), + ')')) + + +class MaxDBDefaultRunner(engine_base.DefaultRunner): + def visit_sequence(self, seq): + if seq.optional: + return None + return self.execute_string("SELECT %s.NEXTVAL FROM DUAL" % ( + self.dialect.identifier_preparer.format_sequence(seq))) + + +class MaxDBIdentifierPreparer(compiler.IdentifierPreparer): + reserved_words = set([ + 'abs', 'absolute', 'acos', 'adddate', 'addtime', 'all', 'alpha', + 'alter', 'any', 'ascii', 'asin', 'atan', 'atan2', 'avg', 'binary', + 'bit', 'boolean', 'byte', 'case', 'ceil', 'ceiling', 'char', + 'character', 'check', 'chr', 'column', 'concat', 'constraint', 'cos', + 'cosh', 'cot', 'count', 'cross', 'curdate', 'current', 'curtime', + 'database', 'date', 'datediff', 'day', 'dayname', 'dayofmonth', + 'dayofweek', 'dayofyear', 'dec', 'decimal', 'decode', 'default', + 'degrees', 'delete', 'digits', 'distinct', 'double', 'except', + 'exists', 'exp', 'expand', 'first', 'fixed', 'float', 'floor', 'for', + 'from', 'full', 'get_objectname', 'get_schema', 'graphic', 'greatest', + 'group', 'having', 'hex', 'hextoraw', 'hour', 'ifnull', 'ignore', + 'index', 'initcap', 'inner', 'insert', 'int', 'integer', 'internal', + 'intersect', 'into', 'join', 'key', 'last', 'lcase', 'least', 'left', + 'length', 'lfill', 'list', 'ln', 'locate', 'log', 'log10', 'long', + 'longfile', 'lower', 'lpad', 'ltrim', 'makedate', 'maketime', + 'mapchar', 'max', 'mbcs', 'microsecond', 'min', 'minute', 'mod', + 'month', 'monthname', 'natural', 'nchar', 'next', 'no', 'noround', + 'not', 'now', 'null', 'num', 'numeric', 'object', 'of', 'on', + 'order', 'packed', 'pi', 'power', 'prev', 'primary', 'radians', + 'real', 'reject', 'relative', 'replace', 'rfill', 'right', 'round', + 'rowid', 'rowno', 'rpad', 'rtrim', 'second', 'select', 'selupd', + 'serial', 'set', 'show', 'sign', 'sin', 'sinh', 'smallint', 'some', + 'soundex', 'space', 'sqrt', 'stamp', 'statistics', 'stddev', + 'subdate', 'substr', 'substring', 'subtime', 'sum', 'sysdba', + 'table', 'tan', 'tanh', 'time', 'timediff', 'timestamp', 'timezone', + 'to', 'toidentifier', 'transaction', 'translate', 'trim', 'trunc', + 'truncate', 'ucase', 'uid', 'unicode', 'union', 'update', 'upper', + 'user', 'usergroup', 'using', 'utcdate', 'utcdiff', 'value', 'values', + 'varchar', 'vargraphic', 'variance', 'week', 'weekofyear', 'when', + 'where', 'with', 'year', 'zoned' ]) + + def _normalize_name(self, name): + if name is None: + return None + if name.isupper(): + lc_name = name.lower() + if not self._requires_quotes(lc_name): + return lc_name + return name + + def _denormalize_name(self, name): + if name is None: + return None + elif (name.islower() and + not self._requires_quotes(name)): + return name.upper() + else: + return name + + def _maybe_quote_identifier(self, name): + if self._requires_quotes(name): + return self.quote_identifier(name) + else: + return 
name + + +class MaxDBDDLCompiler(compiler.DDLCompiler): + def get_column_specification(self, column, **kw): + colspec = [self.preparer.format_column(column), + self.dialect.type_compiler.process(column.type)] + + if not column.nullable: + colspec.append('NOT NULL') + + default = column.default + default_str = self.get_column_default_string(column) + + # No DDL default for columns specified with non-optional sequence; + # this defaulting behavior is entirely client-side. (And as a + # consequence, non-reflectable.) + if (default and isinstance(default, schema.Sequence) and + not default.optional): + pass + # Regular default + elif default_str is not None: + colspec.append('DEFAULT %s' % default_str) + # Assign DEFAULT SERIAL heuristically + elif column.primary_key and column.autoincrement: + # For SERIAL on a non-primary key member, use + # DefaultClause(text('SERIAL')) + try: + first = [c for c in column.table.primary_key.columns + if (c.autoincrement and + (isinstance(c.type, sqltypes.Integer) or + (isinstance(c.type, MaxNumeric) and + c.type.precision)) and + not c.foreign_keys)].pop(0) + if column is first: + colspec.append('DEFAULT SERIAL') + except IndexError: + pass + return ' '.join(colspec) + + def get_column_default_string(self, column): + if isinstance(column.server_default, schema.DefaultClause): + if isinstance(column.server_default.arg, basestring): + if isinstance(column.type, sqltypes.Integer): + return str(column.server_default.arg) + else: + return "'%s'" % column.server_default.arg + else: + return unicode(self._compile(column.server_default.arg, None)) + else: + return None + + def visit_create_sequence(self, create): + """Creates a SEQUENCE. + + TODO: move to module doc? + + start + With an integer value, set the START WITH option. + + increment + An integer value to increment by. Defaults to 1. + + maxdb_minvalue + maxdb_maxvalue + With an integer value, sets the corresponding sequence option. + + maxdb_no_minvalue + maxdb_no_maxvalue + Defaults to False. If true, sets the corresponding sequence option. + + maxdb_cycle + Defaults to False. If true, sets the CYCLE option. + + maxdb_cache + With an integer value, sets the CACHE option. + + maxdb_no_cache + Defaults to False. If true, sets NOCACHE.
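+
+        For example, a sequence created with ``start=1``,
+        ``maxdb_minvalue=1`` and ``maxdb_cycle=True`` would render
+        roughly as (``my_seq`` is an illustrative name)::
+
+            CREATE SEQUENCE my_seq INCREMENT BY 1 START WITH 1 MINVALUE 1 CYCLE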
+ """ + sequence = create.element + + if (not sequence.optional and + (not self.checkfirst or + not self.dialect.has_sequence(self.connection, sequence.name))): + + ddl = ['CREATE SEQUENCE', + self.preparer.format_sequence(sequence)] + + sequence.increment = 1 + + if sequence.increment is not None: + ddl.extend(('INCREMENT BY', str(sequence.increment))) + + if sequence.start is not None: + ddl.extend(('START WITH', str(sequence.start))) + + opts = dict([(pair[0][6:].lower(), pair[1]) + for pair in sequence.kwargs.items() + if pair[0].startswith('maxdb_')]) + + if 'maxvalue' in opts: + ddl.extend(('MAXVALUE', str(opts['maxvalue']))) + elif opts.get('no_maxvalue', False): + ddl.append('NOMAXVALUE') + if 'minvalue' in opts: + ddl.extend(('MINVALUE', str(opts['minvalue']))) + elif opts.get('no_minvalue', False): + ddl.append('NOMINVALUE') + + if opts.get('cycle', False): + ddl.append('CYCLE') + + if 'cache' in opts: + ddl.extend(('CACHE', str(opts['cache']))) + elif opts.get('no_cache', False): + ddl.append('NOCACHE') + + return ' '.join(ddl) + + +class MaxDBDialect(default.DefaultDialect): + name = 'maxdb' + supports_alter = True + supports_unicode_statements = True + max_identifier_length = 32 + supports_sane_rowcount = True + supports_sane_multi_rowcount = False + + preparer = MaxDBIdentifierPreparer + statement_compiler = MaxDBCompiler + ddl_compiler = MaxDBDDLCompiler + defaultrunner = MaxDBDefaultRunner + execution_ctx_cls = MaxDBExecutionContext + + colspecs = colspecs + ischema_names = ischema_names + + # MaxDB-specific + datetimeformat = 'internal' + + def __init__(self, _raise_known_sql_errors=False, **kw): + super(MaxDBDialect, self).__init__(**kw) + self._raise_known = _raise_known_sql_errors + + if self.dbapi is None: + self.dbapi_type_map = {} + else: + self.dbapi_type_map = { + 'Long Binary': MaxBlob(), + 'Long byte_t': MaxBlob(), + 'Long Unicode': MaxText(), + 'Timestamp': MaxTimestamp(), + 'Date': MaxDate(), + 'Time': MaxTime(), + datetime.datetime: MaxTimestamp(), + datetime.date: MaxDate(), + datetime.time: MaxTime(), + } + + def do_execute(self, cursor, statement, parameters, context=None): + res = cursor.execute(statement, parameters) + if isinstance(res, int) and context is not None: + context._rowcount = res + + def do_release_savepoint(self, connection, name): + # Does MaxDB truly support RELEASE SAVEPOINT <id>? All my attempts + # produce "SUBTRANS COMMIT/ROLLBACK not allowed without SUBTRANS + # BEGIN SQLSTATE: I7065" + # Note that ROLLBACK TO works fine. In theory, a RELEASE should + # just free up some transactional resources early, before the overall + # COMMIT/ROLLBACK so omitting it should be relatively ok. + pass + + def get_default_schema_name(self, connection): + try: + return self._default_schema_name + except AttributeError: + name = self.identifier_preparer._normalize_name( + connection.execute('SELECT CURRENT_SCHEMA FROM DUAL').scalar()) + self._default_schema_name = name + return name + + def has_table(self, connection, table_name, schema=None): + denormalize = self.identifier_preparer._denormalize_name + bind = [denormalize(table_name)] + if schema is None: + sql = ("SELECT tablename FROM TABLES " + "WHERE TABLES.TABLENAME=? AND" + " TABLES.SCHEMANAME=CURRENT_SCHEMA ") + else: + sql = ("SELECT tablename FROM TABLES " + "WHERE TABLES.TABLENAME = ? AND" + " TABLES.SCHEMANAME=? 
") + bind.append(denormalize(schema)) + + rp = connection.execute(sql, bind) + found = bool(rp.fetchone()) + rp.close() + return found + + def table_names(self, connection, schema): + if schema is None: + sql = (" SELECT TABLENAME FROM TABLES WHERE " + " SCHEMANAME=CURRENT_SCHEMA ") + rs = connection.execute(sql) + else: + sql = (" SELECT TABLENAME FROM TABLES WHERE " + " SCHEMANAME=? ") + matchname = self.identifier_preparer._denormalize_name(schema) + rs = connection.execute(sql, matchname) + normalize = self.identifier_preparer._normalize_name + return [normalize(row[0]) for row in rs] + + def reflecttable(self, connection, table, include_columns): + denormalize = self.identifier_preparer._denormalize_name + normalize = self.identifier_preparer._normalize_name + + st = ('SELECT COLUMNNAME, MODE, DATATYPE, CODETYPE, LEN, DEC, ' + ' NULLABLE, "DEFAULT", DEFAULTFUNCTION ' + 'FROM COLUMNS ' + 'WHERE TABLENAME=? AND SCHEMANAME=%s ' + 'ORDER BY POS') + + fk = ('SELECT COLUMNNAME, FKEYNAME, ' + ' REFSCHEMANAME, REFTABLENAME, REFCOLUMNNAME, RULE, ' + ' (CASE WHEN REFSCHEMANAME = CURRENT_SCHEMA ' + ' THEN 1 ELSE 0 END) AS in_schema ' + 'FROM FOREIGNKEYCOLUMNS ' + 'WHERE TABLENAME=? AND SCHEMANAME=%s ' + 'ORDER BY FKEYNAME ') + + params = [denormalize(table.name)] + if not table.schema: + st = st % 'CURRENT_SCHEMA' + fk = fk % 'CURRENT_SCHEMA' + else: + st = st % '?' + fk = fk % '?' + params.append(denormalize(table.schema)) + + rows = connection.execute(st, params).fetchall() + if not rows: + raise exc.NoSuchTableError(table.fullname) + + include_columns = set(include_columns or []) + + for row in rows: + (name, mode, col_type, encoding, length, scale, + nullable, constant_def, func_def) = row + + name = normalize(name) + + if include_columns and name not in include_columns: + continue + + type_args, type_kw = [], {} + if col_type == 'FIXED': + type_args = length, scale + # Convert FIXED(10) DEFAULT SERIAL to our Integer + if (scale == 0 and + func_def is not None and func_def.startswith('SERIAL')): + col_type = 'INTEGER' + type_args = length, + elif col_type in 'FLOAT': + type_args = length, + elif col_type in ('CHAR', 'VARCHAR'): + type_args = length, + type_kw['encoding'] = encoding + elif col_type == 'LONG': + type_kw['encoding'] = encoding + + try: + type_cls = ischema_names[col_type.lower()] + type_instance = type_cls(*type_args, **type_kw) + except KeyError: + util.warn("Did not recognize type '%s' of column '%s'" % + (col_type, name)) + type_instance = sqltypes.NullType + + col_kw = {'autoincrement': False} + col_kw['nullable'] = (nullable == 'YES') + col_kw['primary_key'] = (mode == 'KEY') + + if func_def is not None: + if func_def.startswith('SERIAL'): + if col_kw['primary_key']: + # No special default- let the standard autoincrement + # support handle SERIAL pk columns. 
+ col_kw['autoincrement'] = True + else: + # strip current numbering + col_kw['server_default'] = schema.DefaultClause( + sql.text('SERIAL')) + col_kw['autoincrement'] = True + else: + col_kw['server_default'] = schema.DefaultClause( + sql.text(func_def)) + elif constant_def is not None: + col_kw['server_default'] = schema.DefaultClause(sql.text( + "'%s'" % constant_def.replace("'", "''"))) + + table.append_column(schema.Column(name, type_instance, **col_kw)) + + fk_sets = itertools.groupby(connection.execute(fk, params), + lambda row: row.FKEYNAME) + for fkeyname, fkey in fk_sets: + fkey = list(fkey) + if include_columns: + key_cols = set([r.COLUMNNAME for r in fkey]) + if key_cols != include_columns: + continue + + columns, referants = [], [] + quote = self.identifier_preparer._maybe_quote_identifier + + for row in fkey: + columns.append(normalize(row.COLUMNNAME)) + if table.schema or not row.in_schema: + referants.append('.'.join( + [quote(normalize(row[c])) + for c in ('REFSCHEMANAME', 'REFTABLENAME', + 'REFCOLUMNNAME')])) + else: + referants.append('.'.join( + [quote(normalize(row[c])) + for c in ('REFTABLENAME', 'REFCOLUMNNAME')])) + + constraint_kw = {'name': fkeyname.lower()} + if fkey[0].RULE is not None: + rule = fkey[0].RULE + if rule.startswith('DELETE '): + rule = rule[7:] + constraint_kw['ondelete'] = rule + + table_kw = {} + if table.schema or not row.in_schema: + table_kw['schema'] = normalize(fkey[0].REFSCHEMANAME) + + ref_key = schema._get_table_key(normalize(fkey[0].REFTABLENAME), + table_kw.get('schema')) + if ref_key not in table.metadata.tables: + schema.Table(normalize(fkey[0].REFTABLENAME), + table.metadata, + autoload=True, autoload_with=connection, + **table_kw) + + constraint = schema.ForeignKeyConstraint(columns, referants, link_to_name=True, + **constraint_kw) + table.append_constraint(constraint) + + def has_sequence(self, connection, name): + # [ticket:726] makes this schema-aware. + denormalize = self.identifier_preparer._denormalize_name + sql = ("SELECT sequence_name FROM SEQUENCES " + "WHERE SEQUENCE_NAME=? ") + + rp = connection.execute(sql, denormalize(name)) + found = bool(rp.fetchone()) + rp.close() + return found + + + +def _autoserial_column(table): + """Finds the effective DEFAULT SERIAL column of a Table, if any.""" + + for index, col in enumerate(table.primary_key.columns): + if (isinstance(col.type, (sqltypes.Integer, sqltypes.Numeric)) and + col.autoincrement): + if isinstance(col.default, schema.Sequence): + if col.default.optional: + return index, col + elif (col.default is None or + (not isinstance(col.server_default, schema.DefaultClause))): + return index, col + + return None, None + diff --git a/lib/sqlalchemy/dialects/maxdb/sapdb.py b/lib/sqlalchemy/dialects/maxdb/sapdb.py new file mode 100644 index 000000000..10e61228e --- /dev/null +++ b/lib/sqlalchemy/dialects/maxdb/sapdb.py @@ -0,0 +1,17 @@ +from sqlalchemy.dialects.maxdb.base import MaxDBDialect + +class MaxDB_sapdb(MaxDBDialect): + driver = 'sapdb' + + @classmethod + def dbapi(cls): + from sapdb import dbapi as _dbapi + return _dbapi + + def create_connect_args(self, url): + opts = url.translate_connect_args(username='user') + opts.update(url.query) + return [], opts + + +dialect = MaxDB_sapdb
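+
+# For illustration (hypothetical values): given the URL
+# 'maxdb+sapdb://SCOTT:tiger@dbhost/TESTDB', create_connect_args()
+# above yields connect keywords roughly like
+# {'user': 'SCOTT', 'password': 'tiger', 'host': 'dbhost',
+#  'database': 'TESTDB'}, passed straight through to sapdb.dbapi.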
\ No newline at end of file diff --git a/lib/sqlalchemy/dialects/mssql/__init__.py b/lib/sqlalchemy/dialects/mssql/__init__.py new file mode 100644 index 000000000..e3a829047 --- /dev/null +++ b/lib/sqlalchemy/dialects/mssql/__init__.py @@ -0,0 +1,3 @@ +from sqlalchemy.dialects.mssql import base, pyodbc, adodbapi, pymssql + +base.dialect = pyodbc.dialect
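+
+# pyodbc is the default DBAPI here: a plain 'mssql://' URL resolves to
+# it unless another driver is named explicitly, e.g. 'mssql+pymssql://'.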
\ No newline at end of file diff --git a/lib/sqlalchemy/dialects/mssql/adodbapi.py b/lib/sqlalchemy/dialects/mssql/adodbapi.py new file mode 100644 index 000000000..10b8b33b3 --- /dev/null +++ b/lib/sqlalchemy/dialects/mssql/adodbapi.py @@ -0,0 +1,51 @@ +from sqlalchemy import types as sqltypes +from sqlalchemy.dialects.mssql.base import MSDateTime, MSDialect +import datetime +import sys + +class MSDateTime_adodbapi(MSDateTime): + def result_processor(self, dialect): + def process(value): + # adodbapi will return datetimes with empty time values as datetime.date() objects. + # Promote them back to full datetime.datetime() + if type(value) is datetime.date: + return datetime.datetime(value.year, value.month, value.day) + return value + return process + + +class MSDialect_adodbapi(MSDialect): + supports_sane_rowcount = True + supports_sane_multi_rowcount = True + supports_unicode = sys.maxunicode == 65535 + supports_unicode_statements = True + driver = 'adodbapi' + + @classmethod + def import_dbapi(cls): + import adodbapi as module + return module + + colspecs = MSDialect.colspecs.copy() + colspecs[sqltypes.DateTime] = MSDateTime_adodbapi + + def create_connect_args(self, url): + keys = url.query + + connectors = ["Provider=SQLOLEDB"] + if 'port' in keys: + connectors.append("Data Source=%s, %s" % (keys.get("host"), keys.get("port"))) + else: + connectors.append("Data Source=%s" % keys.get("host")) + connectors.append("Initial Catalog=%s" % keys.get("database")) + user = keys.get("user") + if user: + connectors.append("User Id=%s" % user) + connectors.append("Password=%s" % keys.get("password", "")) + else: + connectors.append("Integrated Security=SSPI") + return [[";".join (connectors)], {}] + + def is_disconnect(self, e): + return isinstance(e, self.dbapi.adodbapi.DatabaseError) and "'connection failure'" in str(e) + +dialect = MSDialect_adodbapi diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py new file mode 100644 index 000000000..cd031af40 --- /dev/null +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -0,0 +1,1448 @@ +# mssql.py + +"""Support for the Microsoft SQL Server database. + +Driver +------ + +The MSSQL dialect will work with three different available drivers: + +* *pyodbc* - http://pyodbc.sourceforge.net/. This is the recommended + driver. + +* *pymssql* - http://pymssql.sourceforge.net/ + +* *adodbapi* - http://adodbapi.sourceforge.net/ + +Drivers are loaded in the order listed above based on availability. + +If you need to load a specific driver, pass ``module_name`` when +creating the engine:: + + engine = create_engine('mssql+module_name://dsn') + +``module_name`` currently accepts: ``pyodbc``, ``pymssql``, and +``adodbapi``. + +Currently the pyodbc driver offers the greatest level of +compatibility. + +Connecting +---------- + +Connecting with create_engine() uses the standard URL approach of +``mssql://user:pass@host/dbname[?key=value&key=value...]``. + +If the database name is present, the tokens are converted to a +connection string with the specified values. If the database is not +present, then the host token is taken directly as the DSN name. + +Examples of pyodbc connection string URLs: + +* *mssql+pyodbc://mydsn* - connects using the specified DSN named ``mydsn``. + The connection string that is created will appear like:: + + dsn=mydsn;TrustedConnection=Yes + +* *mssql+pyodbc://user:pass@mydsn* - connects using the DSN named + ``mydsn`` passing in the ``UID`` and ``PWD`` information.
The + connection string that is created will appear like:: + + dsn=mydsn;UID=user;PWD=pass + +* *mssql+pyodbc://user:pass@mydsn/?LANGUAGE=us_english* - connects + using the DSN named ``mydsn`` passing in the ``UID`` and ``PWD`` + information, plus the additional connection configuration option + ``LANGUAGE``. The connection string that is created will appear + like:: + + dsn=mydsn;UID=user;PWD=pass;LANGUAGE=us_english + +* *mssql+pyodbc://user:pass@host/db* - connects using a connection string + dynamically created that would appear like:: + + DRIVER={SQL Server};Server=host;Database=db;UID=user;PWD=pass + +* *mssql+pyodbc://user:pass@host:123/db* - connects using a connection + string that is dynamically created, which also includes the port + information using the comma syntax. If your connection string + requires the port information to be passed as a ``port`` keyword + see the next example. This will create the following connection + string:: + + DRIVER={SQL Server};Server=host,123;Database=db;UID=user;PWD=pass + +* *mssql+pyodbc://user:pass@host/db?port=123* - connects using a connection + string that is dynamically created that includes the port + information as a separate ``port`` keyword. This will create the + following connection string:: + + DRIVER={SQL Server};Server=host;Database=db;UID=user;PWD=pass;port=123 + +If you require a connection string that is outside the options +presented above, use the ``odbc_connect`` keyword to pass in a +urlencoded connection string. What gets passed in will be urldecoded +and passed directly. + +For example:: + + mssql+pyodbc:///?odbc_connect=dsn%3Dmydsn%3BDatabase%3Ddb + +would create the following connection string:: + + dsn=mydsn;Database=db + +Encoding your connection string can be easily accomplished through +the Python shell. For example:: + + >>> import urllib + >>> urllib.quote_plus('dsn=mydsn;Database=db') + 'dsn%3Dmydsn%3BDatabase%3Ddb' + +Additional arguments which may be specified either as query string +arguments on the URL, or as keyword arguments to +:func:`~sqlalchemy.create_engine()` are: + +* *query_timeout* - allows you to override the default query timeout. + Defaults to ``None``. This is only supported on pymssql. + +* *use_scope_identity* - allows you to specify that SCOPE_IDENTITY + should be used in place of the non-scoped version @@IDENTITY. + Defaults to True. + +* *max_identifier_length* - allows you to set the maximum length of + identifiers supported by the database. Defaults to 128. For pymssql + the default is 30. + +* *schema_name* - use to set the schema name. Defaults to ``dbo``. + +Auto Increment Behavior +----------------------- + +``IDENTITY`` columns are supported by using SQLAlchemy +``schema.Sequence()`` objects. In other words:: + + Table('test', metadata, + Column('id', Integer, + Sequence('blah',100,10), primary_key=True), + Column('name', String(20)) + ).create(mss_engine) + +would yield:: + + CREATE TABLE test ( + id INTEGER NOT NULL IDENTITY(100,10) PRIMARY KEY, + name VARCHAR(20) NULL + ) + +Note that the ``start`` and ``increment`` values for sequences are +optional and will default to 1,1. + +* Support for ``SET IDENTITY_INSERT ON`` mode (automatic on/off for + ``INSERT`` statements) + +* Support for auto-fetching of ``@@IDENTITY/@@SCOPE_IDENTITY()`` on + ``INSERT`` + +Collation Support +----------------- + +MSSQL-specific string types support a collation parameter that +creates a column-level specific collation for the column. The +collation parameter accepts a Windows Collation Name or a SQL +Collation Name.
Supported types are MSChar, MSNChar, MSString, +MSNVarchar, MSText, and MSNText. For example:: + + Column('login', String(32, collation='Latin1_General_CI_AS')) + +will yield:: + + login VARCHAR(32) COLLATE Latin1_General_CI_AS NULL + +LIMIT/OFFSET Support +-------------------- + +MSSQL has no support for the LIMIT or OFFSET keywords. LIMIT is +supported directly through the ``TOP`` Transact-SQL keyword:: + + select.limit + +will yield:: + + SELECT TOP n + +If using SQL Server 2005 or above, LIMIT with OFFSET +support is available through the ``ROW_NUMBER OVER`` construct. +For versions below 2005, LIMIT with OFFSET usage will fail. + +Nullability +----------- +MSSQL has support for three levels of column nullability. The default +nullability allows nulls and is explicit in the CREATE TABLE +construct:: + + name VARCHAR(20) NULL + +If ``nullable=None`` is specified then no specification is made. In +other words the database's configured default is used. This will +render:: + + name VARCHAR(20) + +If ``nullable`` is ``True`` or ``False`` then the column will be +``NULL`` or ``NOT NULL`` respectively. + +Date / Time Handling +-------------------- +DATE and TIME are supported. Bind parameters are converted +to datetime.datetime() objects as required by most MSSQL drivers, +and results are processed from strings if needed. +The DATE and TIME types are not available for MSSQL 2005 and +previous; if a server version below 2008 is detected, DDL +for these types will be issued as DATETIME. + +Compatibility Levels +-------------------- +MSSQL supports the notion of setting compatibility levels at the +database level. This allows, for instance, running a database that +is compatible with SQL2000 while running on a SQL2005 database +server. ``server_version_info`` will always return the database +server version information (in this case SQL2005) and not the +compatibility level information. Because of this, if running under +a backwards compatibility mode SQLAlchemy may attempt to use T-SQL +statements that the database server cannot parse.
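+
+For example, to check which server version the dialect detected once a
+connection has been made (the DSN name is illustrative)::
+
+    engine = create_engine('mssql+pyodbc://mydsn')
+    engine.connect()
+    print engine.dialect.server_version_info   # e.g. (9, 0, ...) on SQL2005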
+ +Known Issues +------------ + +* No support for more than one ``IDENTITY`` column per table + +* pymssql has problems with binary and unicode data that this module + does **not** work around + +""" +import datetime, decimal, inspect, operator, sys, re +import itertools + +from sqlalchemy import sql, schema as sa_schema, exc, util +from sqlalchemy.sql import select, compiler, expression, \ + operators as sql_operators, \ + functions as sql_functions, util as sql_util +from sqlalchemy.engine import default, base, reflection +from sqlalchemy import types as sqltypes +from decimal import Decimal as _python_Decimal +from sqlalchemy.types import INTEGER, BIGINT, SMALLINT, DECIMAL, NUMERIC, \ + FLOAT, TIMESTAMP, DATETIME, DATE + + +from sqlalchemy.dialects.mssql import information_schema as ischema + +MS_2008_VERSION = (10,) +MS_2005_VERSION = (9,) +MS_2000_VERSION = (8,) + +RESERVED_WORDS = set( + ['add', 'all', 'alter', 'and', 'any', 'as', 'asc', 'authorization', + 'backup', 'begin', 'between', 'break', 'browse', 'bulk', 'by', 'cascade', + 'case', 'check', 'checkpoint', 'close', 'clustered', 'coalesce', + 'collate', 'column', 'commit', 'compute', 'constraint', 'contains', + 'containstable', 'continue', 'convert', 'create', 'cross', 'current', + 'current_date', 'current_time', 'current_timestamp', 'current_user', + 'cursor', 'database', 'dbcc', 'deallocate', 'declare', 'default', + 'delete', 'deny', 'desc', 'disk', 'distinct', 'distributed', 'double', + 'drop', 'dump', 'else', 'end', 'errlvl', 'escape', 'except', 'exec', + 'execute', 'exists', 'exit', 'external', 'fetch', 'file', 'fillfactor', + 'for', 'foreign', 'freetext', 'freetexttable', 'from', 'full', + 'function', 'goto', 'grant', 'group', 'having', 'holdlock', 'identity', + 'identity_insert', 'identitycol', 'if', 'in', 'index', 'inner', 'insert', + 'intersect', 'into', 'is', 'join', 'key', 'kill', 'left', 'like', + 'lineno', 'load', 'merge', 'national', 'nocheck', 'nonclustered', 'not', + 'null', 'nullif', 'of', 'off', 'offsets', 'on', 'open', 'opendatasource', + 'openquery', 'openrowset', 'openxml', 'option', 'or', 'order', 'outer', + 'over', 'percent', 'pivot', 'plan', 'precision', 'primary', 'print', + 'proc', 'procedure', 'public', 'raiserror', 'read', 'readtext', + 'reconfigure', 'references', 'replication', 'restore', 'restrict', + 'return', 'revert', 'revoke', 'right', 'rollback', 'rowcount', + 'rowguidcol', 'rule', 'save', 'schema', 'securityaudit', 'select', + 'session_user', 'set', 'setuser', 'shutdown', 'some', 'statistics', + 'system_user', 'table', 'tablesample', 'textsize', 'then', 'to', 'top', + 'tran', 'transaction', 'trigger', 'truncate', 'tsequal', 'union', + 'unique', 'unpivot', 'update', 'updatetext', 'use', 'user', 'values', + 'varying', 'view', 'waitfor', 'when', 'where', 'while', 'with', + 'writetext', + ]) + + +class _MSNumeric(sqltypes.Numeric): + def result_processor(self, dialect): + if self.asdecimal: + def process(value): + if value is not None: + return _python_Decimal(str(value)) + else: + return value + return process + else: + def process(value): + if value is None: + return None + return float(value) + return process + + def bind_processor(self, dialect): + def process(value): + if isinstance(value, decimal.Decimal): + if value.adjusted() < 0: + result = "%s0.%s%s" % ( + (value < 0 and '-' or ''), + '0' * (abs(value.adjusted()) - 1), + "".join([str(nint) for nint in value._int])) + + else: + if 'E' in str(value): + result = "%s%s%s" % ( + (value < 0 and '-' or ''), + "".join([str(s) for s in value._int]), + "0" * (value.adjusted() -
(len(value._int)-1))) + else: + if (len(value._int) - 1) > value.adjusted(): + result = "%s%s.%s" % ( + (value < 0 and '-' or ''), + "".join([str(s) for s in value._int][0:value.adjusted() + 1]), + "".join([str(s) for s in value._int][value.adjusted() + 1:])) + else: + result = "%s%s" % ( + (value < 0 and '-' or ''), + "".join([str(s) for s in value._int][0:value.adjusted() + 1])) + + return result + + else: + return value + + return process + +class REAL(sqltypes.Float): + """A type for ``real`` numbers.""" + + __visit_name__ = 'REAL' + + def __init__(self): + super(REAL, self).__init__(precision=24) + +class TINYINT(sqltypes.Integer): + __visit_name__ = 'TINYINT' + + +# MSSQL DATE/TIME types have varied behavior, sometimes returning +# strings. MSDate/TIME check for everything, and always +# filter bind parameters into datetime objects (required by pyodbc, +# not sure about other dialects). + +class _MSDate(sqltypes.Date): + def bind_processor(self, dialect): + def process(value): + if type(value) == datetime.date: + return datetime.datetime(value.year, value.month, value.day) + else: + return value + return process + + _reg = re.compile(r"(\d+)-(\d+)-(\d+)") + def result_processor(self, dialect): + def process(value): + if isinstance(value, datetime.datetime): + return value.date() + elif isinstance(value, basestring): + return datetime.date(*[int(x or 0) for x in self._reg.match(value).groups()]) + else: + return value + return process + +class TIME(sqltypes.TIME): + def __init__(self, precision=None, **kwargs): + self.precision = precision + super(TIME, self).__init__() + + __zero_date = datetime.date(1900, 1, 1) + + def bind_processor(self, dialect): + def process(value): + if isinstance(value, datetime.datetime): + value = datetime.datetime.combine(self.__zero_date, value.time()) + elif isinstance(value, datetime.time): + value = datetime.datetime.combine(self.__zero_date, value) + return value + return process + + _reg = re.compile(r"(\d+):(\d+):(\d+)(?:\.(\d+))?") + def result_processor(self, dialect): + def process(value): + if isinstance(value, datetime.datetime): + return value.time() + elif isinstance(value, basestring): + return datetime.time(*[int(x or 0) for x in self._reg.match(value).groups()]) + else: + return value + return process + + +class _DateTimeBase(object): + def bind_processor(self, dialect): + def process(value): + # TODO: why ? + if type(value) == datetime.date: + return datetime.datetime(value.year, value.month, value.day) + else: + return value + return process + +class _MSDateTime(_DateTimeBase, sqltypes.DateTime): + pass + +class SMALLDATETIME(_DateTimeBase, sqltypes.DateTime): + __visit_name__ = 'SMALLDATETIME' + +class DATETIME2(_DateTimeBase, sqltypes.DateTime): + __visit_name__ = 'DATETIME2' + + def __init__(self, precision=None, **kwargs): + self.precision = precision + + +# TODO: is this not an Interval ? 
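+# (As defined below, DATETIMEOFFSET is DDL-only: values pass through
+# with no bind/result conversion, and only the precision is rendered.)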
+class DATETIMEOFFSET(sqltypes.TypeEngine): + __visit_name__ = 'DATETIMEOFFSET' + + def __init__(self, precision=None, **kwargs): + self.precision = precision + + +class _StringType(object): + """Base for MSSQL string types.""" + + def __init__(self, collation=None): + self.collation = collation + + def __repr__(self): + attributes = inspect.getargspec(self.__init__)[0][1:] + attributes.extend(inspect.getargspec(_StringType.__init__)[0][1:]) + + params = {} + for attr in attributes: + val = getattr(self, attr) + if val is not None and val is not False: + params[attr] = val + + return "%s(%s)" % (self.__class__.__name__, + ', '.join(['%s=%r' % (k, params[k]) for k in params])) + + +class TEXT(_StringType, sqltypes.TEXT): + """MSSQL TEXT type, for variable-length text up to 2^31 characters.""" + + def __init__(self, *args, **kw): + """Construct a TEXT. + + :param collation: Optional, a column-level collation for this string + value. Accepts a Windows Collation Name or a SQL Collation Name. + + """ + collation = kw.pop('collation', None) + _StringType.__init__(self, collation) + sqltypes.Text.__init__(self, *args, **kw) + +class NTEXT(_StringType, sqltypes.UnicodeText): + """MSSQL NTEXT type, for variable-length unicode text up to 2^30 + characters.""" + + __visit_name__ = 'NTEXT' + + def __init__(self, *args, **kwargs): + """Construct a NTEXT. + + :param collation: Optional, a column-level collation for this string + value. Accepts a Windows Collation Name or a SQL Collation Name. + + """ + collation = kwargs.pop('collation', None) + _StringType.__init__(self, collation) + length = kwargs.pop('length', None) + sqltypes.UnicodeText.__init__(self, length, **kwargs) + + +class VARCHAR(_StringType, sqltypes.VARCHAR): + """MSSQL VARCHAR type, for variable-length non-Unicode data with a maximum + of 8,000 characters.""" + + def __init__(self, *args, **kw): + """Construct a VARCHAR. + + :param length: Optional, maximum data length, in characters. + + :param convert_unicode: defaults to False. If True, convert + ``unicode`` data sent to the database to a ``str`` + bytestring, and convert bytestrings coming back from the + database into ``unicode``. + + Bytestrings are encoded using the dialect's + :attr:`~sqlalchemy.engine.base.Dialect.encoding`, which + defaults to `utf-8`. + + If False, may be overridden by + :attr:`sqlalchemy.engine.base.Dialect.convert_unicode`. + + :param assert_unicode: + + If None (the default), no assertion will take place unless + overridden by :attr:`sqlalchemy.engine.base.Dialect.assert_unicode`. + + If 'warn', will issue a runtime warning if a ``str`` + instance is used as a bind value. + + If true, will raise an :exc:`sqlalchemy.exc.InvalidRequestError`. + + :param collation: Optional, a column-level collation for this string + value. Accepts a Windows Collation Name or a SQL Collation Name. + + """ + collation = kw.pop('collation', None) + _StringType.__init__(self, collation) + sqltypes.VARCHAR.__init__(self, *args, **kw) + +class NVARCHAR(_StringType, sqltypes.NVARCHAR): + """MSSQL NVARCHAR type. + + For variable-length unicode character data up to 4,000 characters.""" + + def __init__(self, *args, **kw): + """Construct a NVARCHAR. + + :param length: Optional, Maximum data length, in characters. + + :param collation: Optional, a column-level collation for this string + value. Accepts a Windows Collation Name or a SQL Collation Name.
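+
+          For example (the collation name is illustrative)::
+
+            Column('name', NVARCHAR(100, collation='Latin1_General_CI_AS'))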
+ + """ + collation = kw.pop('collation', None) + _StringType.__init__(self, collation) + sqltypes.NVARCHAR.__init__(self, *args, **kw) + +class CHAR(_StringType, sqltypes.CHAR): + """MSSQL CHAR type, for fixed-length non-Unicode data with a maximum + of 8,000 characters.""" + + def __init__(self, *args, **kw): + """Construct a CHAR. + + :param length: Optinal, maximum data length, in characters. + + :param convert_unicode: defaults to False. If True, convert + ``unicode`` data sent to the database to a ``str`` + bytestring, and convert bytestrings coming back from the + database into ``unicode``. + + Bytestrings are encoded using the dialect's + :attr:`~sqlalchemy.engine.base.Dialect.encoding`, which + defaults to `utf-8`. + + If False, may be overridden by + :attr:`sqlalchemy.engine.base.Dialect.convert_unicode`. + + :param assert_unicode: + + If None (the default), no assertion will take place unless + overridden by :attr:`sqlalchemy.engine.base.Dialect.assert_unicode`. + + If 'warn', will issue a runtime warning if a ``str`` + instance is used as a bind value. + + If true, will raise an :exc:`sqlalchemy.exc.InvalidRequestError`. + + :param collation: Optional, a column-level collation for this string + value. Accepts a Windows Collation Name or a SQL Collation Name. + + """ + collation = kw.pop('collation', None) + _StringType.__init__(self, collation) + sqltypes.CHAR.__init__(self, *args, **kw) + +class NCHAR(_StringType, sqltypes.NCHAR): + """MSSQL NCHAR type. + + For fixed-length unicode character data up to 4,000 characters.""" + + def __init__(self, *args, **kw): + """Construct an NCHAR. + + :param length: Optional, Maximum data length, in characters. + + :param collation: Optional, a column-level collation for this string + value. Accepts a Windows Collation Name or a SQL Collation Name. + + """ + collation = kw.pop('collation', None) + _StringType.__init__(self, collation) + sqltypes.NCHAR.__init__(self, *args, **kw) + +class BINARY(sqltypes.Binary): + __visit_name__ = 'BINARY' + +class VARBINARY(sqltypes.Binary): + __visit_name__ = 'VARBINARY' + +class IMAGE(sqltypes.Binary): + __visit_name__ = 'IMAGE' + +class BIT(sqltypes.TypeEngine): + __visit_name__ = 'BIT' + +class _MSBoolean(sqltypes.Boolean): + def result_processor(self, dialect): + def process(value): + if value is None: + return None + return value and True or False + return process + + def bind_processor(self, dialect): + def process(value): + if value is True: + return 1 + elif value is False: + return 0 + elif value is None: + return None + else: + return value and True or False + return process + +class MONEY(sqltypes.TypeEngine): + __visit_name__ = 'MONEY' + +class SMALLMONEY(sqltypes.TypeEngine): + __visit_name__ = 'SMALLMONEY' + +class UNIQUEIDENTIFIER(sqltypes.TypeEngine): + __visit_name__ = "UNIQUEIDENTIFIER" + +class SQL_VARIANT(sqltypes.TypeEngine): + __visit_name__ = 'SQL_VARIANT' + +# old names. 
+MSNumeric = _MSNumeric +MSDateTime = _MSDateTime +MSDate = _MSDate +MSBoolean = _MSBoolean +MSReal = REAL +MSTinyInteger = TINYINT +MSTime = TIME +MSSmallDateTime = SMALLDATETIME +MSDateTime2 = DATETIME2 +MSDateTimeOffset = DATETIMEOFFSET +MSText = TEXT +MSNText = NTEXT +MSString = VARCHAR +MSNVarchar = NVARCHAR +MSChar = CHAR +MSNChar = NCHAR +MSBinary = BINARY +MSVarBinary = VARBINARY +MSImage = IMAGE +MSBit = BIT +MSMoney = MONEY +MSSmallMoney = SMALLMONEY +MSUniqueIdentifier = UNIQUEIDENTIFIER +MSVariant = SQL_VARIANT + +colspecs = { + sqltypes.Numeric : _MSNumeric, + sqltypes.DateTime : _MSDateTime, + sqltypes.Date : _MSDate, + sqltypes.Time : TIME, + sqltypes.Boolean : _MSBoolean, +} + +ischema_names = { + 'int' : INTEGER, + 'bigint': BIGINT, + 'smallint' : SMALLINT, + 'tinyint' : TINYINT, + 'varchar' : VARCHAR, + 'nvarchar' : NVARCHAR, + 'char' : CHAR, + 'nchar' : NCHAR, + 'text' : TEXT, + 'ntext' : NTEXT, + 'decimal' : DECIMAL, + 'numeric' : NUMERIC, + 'float' : FLOAT, + 'datetime' : DATETIME, + 'datetime2' : DATETIME2, + 'datetimeoffset' : DATETIMEOFFSET, + 'date': DATE, + 'time': TIME, + 'smalldatetime' : SMALLDATETIME, + 'binary' : BINARY, + 'varbinary' : VARBINARY, + 'bit': BIT, + 'real' : REAL, + 'image' : IMAGE, + 'timestamp': TIMESTAMP, + 'money': MONEY, + 'smallmoney': SMALLMONEY, + 'uniqueidentifier': UNIQUEIDENTIFIER, + 'sql_variant': SQL_VARIANT, +} + + +class MSTypeCompiler(compiler.GenericTypeCompiler): + def _extend(self, spec, type_): + """Extend a string-type declaration with standard SQL + COLLATE annotations. + + """ + + if getattr(type_, 'collation', None): + collation = 'COLLATE %s' % type_.collation + else: + collation = None + + if type_.length: + spec = spec + "(%d)" % type_.length + + return ' '.join([c for c in (spec, collation) + if c is not None]) + + def visit_FLOAT(self, type_): + precision = getattr(type_, 'precision', None) + if precision is None: + return "FLOAT" + else: + return "FLOAT(%(precision)s)" % {'precision': precision} + + def visit_REAL(self, type_): + return "REAL" + + def visit_TINYINT(self, type_): + return "TINYINT" + + def visit_DATETIMEOFFSET(self, type_): + if type_.precision: + return "DATETIMEOFFSET(%s)" % type_.precision + else: + return "DATETIMEOFFSET" + + def visit_TIME(self, type_): + precision = getattr(type_, 'precision', None) + if precision: + return "TIME(%s)" % precision + else: + return "TIME" + + def visit_DATETIME2(self, type_): + precision = getattr(type_, 'precision', None) + if precision: + return "DATETIME2(%s)" % precision + else: + return "DATETIME2" + + def visit_SMALLDATETIME(self, type_): + return "SMALLDATETIME" + + def visit_unicode(self, type_): + return self.visit_NVARCHAR(type_) + + def visit_unicode_text(self, type_): + return self.visit_NTEXT(type_) + + def visit_NTEXT(self, type_): + return self._extend("NTEXT", type_) + + def visit_TEXT(self, type_): + return self._extend("TEXT", type_) + + def visit_VARCHAR(self, type_): + return self._extend("VARCHAR", type_) + + def visit_CHAR(self, type_): + return self._extend("CHAR", type_) + + def visit_NCHAR(self, type_): + return self._extend("NCHAR", type_) + + def visit_NVARCHAR(self, type_): + return self._extend("NVARCHAR", type_) + + def visit_date(self, type_): + if self.dialect.server_version_info < MS_2008_VERSION: + return self.visit_DATETIME(type_) + else: + return self.visit_DATE(type_) + + def visit_time(self, type_): + if self.dialect.server_version_info < MS_2008_VERSION: + return self.visit_DATETIME(type_) + else: + return 
self.visit_TIME(type_) + + def visit_binary(self, type_): + if type_.length: + return self.visit_BINARY(type_) + else: + return self.visit_IMAGE(type_) + + def visit_BINARY(self, type_): + if type_.length: + return "BINARY(%s)" % type_.length + else: + return "BINARY" + + def visit_IMAGE(self, type_): + return "IMAGE" + + def visit_VARBINARY(self, type_): + if type_.length: + return "VARBINARY(%s)" % type_.length + else: + return "VARBINARY" + + def visit_boolean(self, type_): + return self.visit_BIT(type_) + + def visit_BIT(self, type_): + return "BIT" + + def visit_MONEY(self, type_): + return "MONEY" + + def visit_SMALLMONEY(self, type_): + return 'SMALLMONEY' + + def visit_UNIQUEIDENTIFIER(self, type_): + return "UNIQUEIDENTIFIER" + + def visit_SQL_VARIANT(self, type_): + return 'SQL_VARIANT' + +class MSExecutionContext(default.DefaultExecutionContext): + _enable_identity_insert = False + _select_lastrowid = False + _result_proxy = None + _lastrowid = None + + def pre_exec(self): + """Activate IDENTITY_INSERT if needed.""" + + if self.isinsert: + tbl = self.compiled.statement.table + seq_column = tbl._autoincrement_column + insert_has_sequence = seq_column is not None + + if insert_has_sequence: + self._enable_identity_insert = seq_column.key in self.compiled_parameters[0] + else: + self._enable_identity_insert = False + + self._select_lastrowid = insert_has_sequence and \ + not self.compiled.returning and \ + not self._enable_identity_insert and \ + not self.executemany + + if self._enable_identity_insert: + self.cursor.execute("SET IDENTITY_INSERT %s ON" % + self.dialect.identifier_preparer.format_table(tbl)) + + def post_exec(self): + """Disable IDENTITY_INSERT if enabled.""" + + if self._select_lastrowid: + if self.dialect.use_scope_identity: + self.cursor.execute("SELECT scope_identity() AS lastrowid") + else: + self.cursor.execute("SELECT @@identity AS lastrowid") + row = self.cursor.fetchall()[0] # fetchall() ensures the cursor is consumed without closing it + self._lastrowid = int(row[0]) + + if (self.isinsert or self.isupdate or self.isdelete) and self.compiled.returning: + self._result_proxy = base.FullyBufferedResultProxy(self) + + if self._enable_identity_insert: + self.cursor.execute("SET IDENTITY_INSERT %s OFF" % self.dialect.identifier_preparer.format_table(self.compiled.statement.table)) + + def get_lastrowid(self): + return self._lastrowid + + def handle_dbapi_exception(self, e): + if self._enable_identity_insert: + try: + self.cursor.execute("SET IDENTITY_INSERT %s OFF" % self.dialect.identifier_preparer.format_table(self.compiled.statement.table)) + except: + pass + + def get_result_proxy(self): + if self._result_proxy: + return self._result_proxy + else: + return base.ResultProxy(self) + +class MSSQLCompiler(compiler.SQLCompiler): + + extract_map = compiler.SQLCompiler.extract_map.copy() + extract_map.update ({ + 'doy': 'dayofyear', + 'dow': 'weekday', + 'milliseconds': 'millisecond', + 'microseconds': 'microsecond' + }) + + def __init__(self, *args, **kwargs): + super(MSSQLCompiler, self).__init__(*args, **kwargs) + self.tablealiases = {} + + def visit_now_func(self, fn, **kw): + return "CURRENT_TIMESTAMP" + + def visit_current_date_func(self, fn, **kw): + return "GETDATE()" + + def visit_length_func(self, fn, **kw): + return "LEN%s" % self.function_argspec(fn, **kw) + + def visit_char_length_func(self, fn, **kw): + return "LEN%s" % self.function_argspec(fn, **kw) + + def visit_concat_op(self, binary): + return "%s + %s" % (self.process(binary.left), 
self.process(binary.right)) + + def visit_match_op(self, binary): + return "CONTAINS (%s, %s)" % (self.process(binary.left), self.process(binary.right)) + + def get_select_precolumns(self, select): + """MS-SQL puts TOP, its version of LIMIT, here.""" + if select._distinct or select._limit: + s = select._distinct and "DISTINCT " or "" + + if select._limit: + if not select._offset: + s += "TOP %s " % (select._limit,) + return s + return compiler.SQLCompiler.get_select_precolumns(self, select) + + def limit_clause(self, select): + # Limit in mssql is after the select keyword + return "" + + def visit_select(self, select, **kwargs): + """Look for ``LIMIT`` and ``OFFSET`` in a select statement; if an + OFFSET is present, wrap the statement in a subquery using a + ``row_number()`` criterion. + + """ + if not getattr(select, '_mssql_visit', None) and select._offset: + # to use ROW_NUMBER(), an ORDER BY is required. + orderby = self.process(select._order_by_clause) + if not orderby: + raise exc.InvalidRequestError('MSSQL requires an order_by when using an offset.') + + _offset = select._offset + _limit = select._limit + select._mssql_visit = True + select = select.column(sql.literal_column("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("mssql_rn")).order_by(None).alias() + + limitselect = sql.select([c for c in select.c if c.key!='mssql_rn']) + limitselect.append_whereclause("mssql_rn>%d" % _offset) + if _limit is not None: + limitselect.append_whereclause("mssql_rn<=%d" % (_limit + _offset)) + return self.process(limitselect, iswrapper=True, **kwargs) + else: + return compiler.SQLCompiler.visit_select(self, select, **kwargs) + + def _schema_aliased_table(self, table): + if getattr(table, 'schema', None) is not None: + if table not in self.tablealiases: + self.tablealiases[table] = table.alias() + return self.tablealiases[table] + else: + return None + + def visit_table(self, table, mssql_aliased=False, **kwargs): + if mssql_aliased: + return super(MSSQLCompiler, self).visit_table(table, **kwargs) + + # alias schema-qualified tables + alias = self._schema_aliased_table(table) + if alias is not None: + return self.process(alias, mssql_aliased=True, **kwargs) + else: + return super(MSSQLCompiler, self).visit_table(table, **kwargs) + + def visit_alias(self, alias, **kwargs): + # translate for schema-qualified table aliases + self.tablealiases[alias.original] = alias + kwargs['mssql_aliased'] = True + return super(MSSQLCompiler, self).visit_alias(alias, **kwargs) + + def visit_extract(self, extract): + field = self.extract_map.get(extract.field, extract.field) + return 'DATEPART("%s", %s)' % (field, self.process(extract.expr)) + + def visit_rollback_to_savepoint(self, savepoint_stmt): + return "ROLLBACK TRANSACTION %s" % self.preparer.format_savepoint(savepoint_stmt) + + def visit_column(self, column, result_map=None, **kwargs): + if column.table is not None and \ + (not self.isupdate and not self.isdelete) or self.is_subquery(): + # translate for schema-qualified table aliases + t = self._schema_aliased_table(column.table) + if t is not None: + converted = expression._corresponding_column_or_error(t, column) + + if result_map is not None: + result_map[column.name.lower()] = (column.name, (column, ), column.type) + + return super(MSSQLCompiler, self).visit_column(converted, result_map=None, **kwargs) + + return super(MSSQLCompiler, self).visit_column(column, result_map=result_map, **kwargs) + + def visit_binary(self, binary, **kwargs): + """Move bind parameters to the right-hand side of an operator, where + possible.
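+
+        (SQL Server can have trouble deriving a type for a bind
+        parameter that appears on the left side of a comparison, so
+        ``bind == column`` is rewritten as ``column == bind`` when it
+        is safe to do so.)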
+ + """ + if isinstance(binary.left, expression._BindParamClause) and binary.operator == operator.eq \ + and not isinstance(binary.right, expression._BindParamClause): + return self.process(expression._BinaryExpression(binary.right, binary.left, binary.operator), **kwargs) + else: + if (binary.operator is operator.eq or binary.operator is operator.ne) and ( + (isinstance(binary.left, expression._FromGrouping) and isinstance(binary.left.element, expression._ScalarSelect)) or \ + (isinstance(binary.right, expression._FromGrouping) and isinstance(binary.right.element, expression._ScalarSelect)) or \ + isinstance(binary.left, expression._ScalarSelect) or isinstance(binary.right, expression._ScalarSelect)): + op = binary.operator == operator.eq and "IN" or "NOT IN" + return self.process(expression._BinaryExpression(binary.left, binary.right, op), **kwargs) + return super(MSSQLCompiler, self).visit_binary(binary, **kwargs) + + def returning_clause(self, stmt, returning_cols): + + if self.isinsert or self.isupdate: + target = stmt.table.alias("inserted") + else: + target = stmt.table.alias("deleted") + + adapter = sql_util.ClauseAdapter(target) + def col_label(col): + adapted = adapter.traverse(c) + if isinstance(c, expression._Label): + return adapted.label(c.key) + else: + return self.label_select_column(None, adapted, asfrom=False) + + columns = [ + self.process( + col_label(c), + within_columns_clause=True, + result_map=self.result_map + ) + for c in expression._select_iterables(returning_cols) + ] + return 'OUTPUT ' + ', '.join(columns) + + def label_select_column(self, select, column, asfrom): + if isinstance(column, expression.Function): + return column.label(None) + else: + return super(MSSQLCompiler, self).label_select_column(select, column, asfrom) + + def for_update_clause(self, select): + # "FOR UPDATE" is only allowed on "DECLARE CURSOR" which SQLAlchemy doesn't use + return '' + + def order_by_clause(self, select): + order_by = self.process(select._order_by_clause) + + # MSSQL only allows ORDER BY in subqueries if there is a LIMIT + if order_by and (not self.is_subquery() or select._limit): + return " ORDER BY " + order_by + else: + return "" + + +class MSDDLCompiler(compiler.DDLCompiler): + def get_column_specification(self, column, **kwargs): + colspec = self.preparer.format_column(column) + " " + self.dialect.type_compiler.process(column.type) + + if column.nullable is not None: + if not column.nullable or column.primary_key: + colspec += " NOT NULL" + else: + colspec += " NULL" + + if not column.table: + raise exc.InvalidRequestError("mssql requires Table-bound columns in order to generate DDL") + + seq_col = column.table._autoincrement_column + + # install a IDENTITY Sequence if we have an implicit IDENTITY column + if seq_col is column: + sequence = getattr(column, 'sequence', None) + if sequence: + start, increment = sequence.start or 1, sequence.increment or 1 + else: + start, increment = 1, 1 + colspec += " IDENTITY(%s,%s)" % (start, increment) + else: + default = self.get_column_default_string(column) + if default is not None: + colspec += " DEFAULT " + default + + return colspec + + def visit_drop_index(self, drop): + return "\nDROP INDEX %s.%s" % ( + self.preparer.quote_identifier(drop.element.table.name), + self.preparer.quote(self._validate_identifier(drop.element.name, False), drop.element.quote) + ) + + +class MSIdentifierPreparer(compiler.IdentifierPreparer): + reserved_words = RESERVED_WORDS + + def __init__(self, dialect): + super(MSIdentifierPreparer, 
self).__init__(dialect, initial_quote='[', final_quote=']') + + def _escape_identifier(self, value): + #TODO: determine MSSQL's escaping rules + return value + + def quote_schema(self, schema, force=True): + """Prepare a quoted table and schema name.""" + result = '.'.join([self.quote(x, force) for x in schema.split('.')]) + return result + +class MSDialect(default.DefaultDialect): + name = 'mssql' + supports_default_values = True + supports_empty_insert = False + execution_ctx_cls = MSExecutionContext + text_as_varchar = False + use_scope_identity = True + max_identifier_length = 128 + schema_name = "dbo" + colspecs = colspecs + ischema_names = ischema_names + + supports_unicode_binds = True + postfetch_lastrowid = True + + server_version_info = () + + statement_compiler = MSSQLCompiler + ddl_compiler = MSDDLCompiler + type_compiler = MSTypeCompiler + preparer = MSIdentifierPreparer + + def __init__(self, + query_timeout=None, + use_scope_identity=True, + max_identifier_length=None, + schema_name="dbo", **opts): + self.query_timeout = int(query_timeout or 0) + self.schema_name = schema_name + + self.use_scope_identity = use_scope_identity + self.max_identifier_length = int(max_identifier_length or 0) or \ + self.max_identifier_length + super(MSDialect, self).__init__(**opts) + + def do_savepoint(self, connection, name): + util.warn("Savepoint support in mssql is experimental and may lead to data loss.") + connection.execute("IF @@TRANCOUNT = 0 BEGIN TRANSACTION") + connection.execute("SAVE TRANSACTION %s" % name) + + def do_release_savepoint(self, connection, name): + pass + + def initialize(self, connection): + super(MSDialect, self).initialize(connection) + if self.server_version_info >= MS_2005_VERSION and 'implicit_returning' not in self.__dict__: + self.implicit_returning = True + + def get_default_schema_name(self, connection): + return self.default_schema_name + + def _get_default_schema_name(self, connection): + user_name = connection.scalar("SELECT user_name() as user_name;") + if user_name is not None: + # now, get the default schema + query = """ + SELECT default_schema_name FROM + sys.database_principals + WHERE name = ? 
+ AND type = 'S' + """ + try: + default_schema_name = connection.scalar(query, [user_name]) + if default_schema_name is not None: + return default_schema_name + except: + pass + return self.schema_name + + def table_names(self, connection, schema): + s = select([ischema.tables.c.table_name], ischema.tables.c.table_schema==schema) + return [row[0] for row in connection.execute(s)] + + + def has_table(self, connection, tablename, schema=None): + current_schema = schema or self.default_schema_name + columns = ischema.columns + s = sql.select([columns], + current_schema + and sql.and_(columns.c.table_name==tablename, columns.c.table_schema==current_schema) + or columns.c.table_name==tablename, + ) + + c = connection.execute(s) + row = c.fetchone() + return row is not None + + @reflection.cache + def get_schema_names(self, connection, **kw): + s = sql.select([ischema.schemata.c.schema_name], + order_by=[ischema.schemata.c.schema_name] + ) + schema_names = [r[0] for r in connection.execute(s)] + return schema_names + + @reflection.cache + def get_table_names(self, connection, schema=None, **kw): + current_schema = schema or self.default_schema_name + tables = ischema.tables + s = sql.select([tables.c.table_name], + sql.and_( + tables.c.table_schema == current_schema, + tables.c.table_type == 'BASE TABLE' + ), + order_by=[tables.c.table_name] + ) + table_names = [r[0] for r in connection.execute(s)] + return table_names + + @reflection.cache + def get_view_names(self, connection, schema=None, **kw): + current_schema = schema or self.default_schema_name + tables = ischema.tables + s = sql.select([tables.c.table_name], + sql.and_( + tables.c.table_schema == current_schema, + tables.c.table_type == 'VIEW' + ), + order_by=[tables.c.table_name] + ) + view_names = [r[0] for r in connection.execute(s)] + return view_names + + # The cursor reports it is closed after executing the sp. + @reflection.cache + def get_indexes(self, connection, tablename, schema=None, **kw): + current_schema = schema or self.default_schema_name + full_tname = "%s.%s" % (current_schema, tablename) + indexes = [] + s = sql.text("exec sp_helpindex '%s'" % full_tname) + rp = connection.execute(s) + if rp.closed: + # did not work for this setup. 
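+            # sp_helpindex returns (index_name, index_description, index_keys)
+            # rows; a closed cursor means no result set came back, so report
+            # no indexes rather than failing.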
+ return [] + for row in rp: + if 'primary key' not in row['index_description']: + indexes.append({ + 'name' : row['index_name'], + 'column_names' : [c.strip() for c in row['index_keys'].split(',')], + 'unique': 'unique' in row['index_description'] + }) + return indexes + + @reflection.cache + def get_view_definition(self, connection, viewname, schema=None, **kw): + current_schema = schema or self.default_schema_name + views = ischema.views + s = sql.select([views.c.view_definition], + sql.and_( + views.c.table_schema == current_schema, + views.c.table_name == viewname + ), + ) + rp = connection.execute(s) + if rp: + view_def = rp.scalar() + return view_def + + @reflection.cache + def get_columns(self, connection, tablename, schema=None, **kw): + # Get base columns + current_schema = schema or self.default_schema_name + columns = ischema.columns + s = sql.select([columns], + current_schema + and sql.and_(columns.c.table_name==tablename, columns.c.table_schema==current_schema) + or columns.c.table_name==tablename, + order_by=[columns.c.ordinal_position]) + c = connection.execute(s) + cols = [] + while True: + row = c.fetchone() + if row is None: + break + (name, type, nullable, charlen, numericprec, numericscale, default, collation) = ( + row[columns.c.column_name], + row[columns.c.data_type], + row[columns.c.is_nullable] == 'YES', + row[columns.c.character_maximum_length], + row[columns.c.numeric_precision], + row[columns.c.numeric_scale], + row[columns.c.column_default], + row[columns.c.collation_name] + ) + coltype = self.ischema_names.get(type, None) + + kwargs = {} + if coltype in (MSString, MSChar, MSNVarchar, MSNChar, MSText, MSNText, MSBinary, MSVarBinary, sqltypes.Binary): + kwargs['length'] = charlen + if collation: + kwargs['collation'] = collation + if coltype == MSText or (coltype in (MSString, MSNVarchar) and charlen == -1): + kwargs.pop('length') + + if coltype is None: + util.warn("Did not recognize type '%s' of column '%s'" % (type, name)) + coltype = sqltypes.NULLTYPE + + if issubclass(coltype, sqltypes.Numeric) and coltype is not MSReal: + kwargs['scale'] = numericscale + kwargs['precision'] = numericprec + + coltype = coltype(**kwargs) + cdict = { + 'name' : name, + 'type' : coltype, + 'nullable' : nullable, + 'default' : default, + 'autoincrement':False, + } + cols.append(cdict) + # autoincrement and identity + colmap = {} + for col in cols: + colmap[col['name']] = col + # We also run an sp_columns to check for identity columns: + cursor = connection.execute("sp_columns @table_name = '%s', @table_owner = '%s'" % (tablename, current_schema)) + ic = None + while True: + row = cursor.fetchone() + if row is None: + break + (col_name, type_name) = row[3], row[5] + if type_name.endswith("identity") and col_name in colmap: + ic = col_name + colmap[col_name]['autoincrement'] = True + colmap[col_name]['sequence'] = dict( + name='%s_identity' % col_name) + break + cursor.close() + if ic is not None: + try: + # is this table_fullname reliable? 
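+                # T-SQL's ident_seed() and ident_incr() take a table name and
+                # return that table's IDENTITY start and increment values.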
+ table_fullname = "%s.%s" % (current_schema, tablename) + cursor = connection.execute( + sql.text("select ident_seed(:seed), ident_incr(:incr)"), + {'seed':table_fullname, 'incr':table_fullname} + ) + row = cursor.fetchone() + cursor.close() + if not row is None: + colmap[ic]['sequence'].update({ + 'start' : int(row[0]), + 'increment' : int(row[1]) + }) + except: + # ignoring it, works just like before + pass + return cols + + @reflection.cache + def get_primary_keys(self, connection, tablename, schema=None, **kw): + current_schema = schema or self.default_schema_name + pkeys = [] + # Add constraints + RR = ischema.ref_constraints #information_schema.referential_constraints + TC = ischema.constraints #information_schema.table_constraints + C = ischema.key_constraints.alias('C') #information_schema.constraint_column_usage: the constrained column + R = ischema.key_constraints.alias('R') #information_schema.constraint_column_usage: the referenced column + + # Primary key constraints + s = sql.select([C.c.column_name, TC.c.constraint_type], + sql.and_(TC.c.constraint_name == C.c.constraint_name, + C.c.table_name == tablename, + C.c.table_schema == current_schema) + ) + c = connection.execute(s) + for row in c: + if 'PRIMARY' in row[TC.c.constraint_type.name]: + pkeys.append(row[0]) + return pkeys + + @reflection.cache + def get_foreign_keys(self, connection, tablename, schema=None, **kw): + current_schema = schema or self.default_schema_name + # Add constraints + RR = ischema.ref_constraints #information_schema.referential_constraints + TC = ischema.constraints #information_schema.table_constraints + C = ischema.key_constraints.alias('C') #information_schema.constraint_column_usage: the constrained column + R = ischema.key_constraints.alias('R') #information_schema.constraint_column_usage: the referenced column + + # Foreign key constraints + s = sql.select([C.c.column_name, + R.c.table_schema, R.c.table_name, R.c.column_name, + RR.c.constraint_name, RR.c.match_option, RR.c.update_rule, RR.c.delete_rule], + sql.and_(C.c.table_name == tablename, + C.c.table_schema == current_schema, + C.c.constraint_name == RR.c.constraint_name, + R.c.constraint_name == RR.c.unique_constraint_name, + C.c.ordinal_position == R.c.ordinal_position + ), + order_by = [RR.c.constraint_name, R.c.ordinal_position]) + + + # group rows by constraint ID, to handle multi-column FKs + fkeys = [] + fknm, scols, rcols = (None, [], []) + + def fkey_rec(): + return { + 'name' : None, + 'constrained_columns' : [], + 'referred_schema' : None, + 'referred_table' : None, + 'referred_columns' : [] + } + + fkeys = util.defaultdict(fkey_rec) + + for r in connection.execute(s).fetchall(): + scol, rschema, rtbl, rcol, rfknm, fkmatch, fkuprule, fkdelrule = r + + rec = fkeys[rfknm] + rec['name'] = rfknm + if not rec['referred_table']: + rec['referred_table'] = rtbl + + if schema is not None or current_schema != rschema: + rec['referred_schema'] = rschema + + local_cols, remote_cols = rec['constrained_columns'], rec['referred_columns'] + + local_cols.append(scol) + remote_cols.append(rcol) + + return fkeys.values() + + +# fixme. I added this for the tests to run. 
-Randall +MSSQLDialect = MSDialect diff --git a/lib/sqlalchemy/dialects/mssql/information_schema.py b/lib/sqlalchemy/dialects/mssql/information_schema.py new file mode 100644 index 000000000..bb6ff315a --- /dev/null +++ b/lib/sqlalchemy/dialects/mssql/information_schema.py @@ -0,0 +1,83 @@ +from sqlalchemy import Table, MetaData, Column, ForeignKey +from sqlalchemy.types import String, Unicode, Integer, TypeDecorator + +ischema = MetaData() + +class CoerceUnicode(TypeDecorator): + impl = Unicode + + def process_bind_param(self, value, dialect): + if isinstance(value, str): + value = value.decode(dialect.encoding) + return value + +schemata = Table("SCHEMATA", ischema, + Column("CATALOG_NAME", CoerceUnicode, key="catalog_name"), + Column("SCHEMA_NAME", CoerceUnicode, key="schema_name"), + Column("SCHEMA_OWNER", CoerceUnicode, key="schema_owner"), + schema="INFORMATION_SCHEMA") + +tables = Table("TABLES", ischema, + Column("TABLE_CATALOG", CoerceUnicode, key="table_catalog"), + Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + Column("TABLE_NAME", CoerceUnicode, key="table_name"), + Column("TABLE_TYPE", String, key="table_type"), + schema="INFORMATION_SCHEMA") + +columns = Table("COLUMNS", ischema, + Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + Column("TABLE_NAME", CoerceUnicode, key="table_name"), + Column("COLUMN_NAME", CoerceUnicode, key="column_name"), + Column("IS_NULLABLE", Integer, key="is_nullable"), + Column("DATA_TYPE", String, key="data_type"), + Column("ORDINAL_POSITION", Integer, key="ordinal_position"), + Column("CHARACTER_MAXIMUM_LENGTH", Integer, key="character_maximum_length"), + Column("NUMERIC_PRECISION", Integer, key="numeric_precision"), + Column("NUMERIC_SCALE", Integer, key="numeric_scale"), + Column("COLUMN_DEFAULT", Integer, key="column_default"), + Column("COLLATION_NAME", String, key="collation_name"), + schema="INFORMATION_SCHEMA") + +constraints = Table("TABLE_CONSTRAINTS", ischema, + Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + Column("TABLE_NAME", CoerceUnicode, key="table_name"), + Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"), + Column("CONSTRAINT_TYPE", String, key="constraint_type"), + schema="INFORMATION_SCHEMA") + +column_constraints = Table("CONSTRAINT_COLUMN_USAGE", ischema, + Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + Column("TABLE_NAME", CoerceUnicode, key="table_name"), + Column("COLUMN_NAME", CoerceUnicode, key="column_name"), + Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"), + schema="INFORMATION_SCHEMA") + +key_constraints = Table("KEY_COLUMN_USAGE", ischema, + Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + Column("TABLE_NAME", CoerceUnicode, key="table_name"), + Column("COLUMN_NAME", CoerceUnicode, key="column_name"), + Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"), + Column("ORDINAL_POSITION", Integer, key="ordinal_position"), + schema="INFORMATION_SCHEMA") + +ref_constraints = Table("REFERENTIAL_CONSTRAINTS", ischema, + Column("CONSTRAINT_CATALOG", CoerceUnicode, key="constraint_catalog"), + Column("CONSTRAINT_SCHEMA", CoerceUnicode, key="constraint_schema"), + Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"), + Column("UNIQUE_CONSTRAINT_CATLOG", CoerceUnicode, key="unique_constraint_catalog"), # TODO: is CATLOG misspelled ? 
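+    # (it is: the INFORMATION_SCHEMA column is spelled UNIQUE_CONSTRAINT_CATALOG,
+    # so this Column as written will not match the server-side name)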
+ Column("UNIQUE_CONSTRAINT_SCHEMA", CoerceUnicode, key="unique_constraint_schema"), + Column("UNIQUE_CONSTRAINT_NAME", CoerceUnicode, key="unique_constraint_name"), + Column("MATCH_OPTION", String, key="match_option"), + Column("UPDATE_RULE", String, key="update_rule"), + Column("DELETE_RULE", String, key="delete_rule"), + schema="INFORMATION_SCHEMA") + +views = Table("VIEWS", ischema, + Column("TABLE_CATALOG", CoerceUnicode, key="table_catalog"), + Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + Column("TABLE_NAME", CoerceUnicode, key="table_name"), + Column("VIEW_DEFINITION", CoerceUnicode, key="view_definition"), + Column("CHECK_OPTION", String, key="check_option"), + Column("IS_UPDATABLE", String, key="is_updatable"), + schema="INFORMATION_SCHEMA") + diff --git a/lib/sqlalchemy/dialects/mssql/pymssql.py b/lib/sqlalchemy/dialects/mssql/pymssql.py new file mode 100644 index 000000000..0961c2e76 --- /dev/null +++ b/lib/sqlalchemy/dialects/mssql/pymssql.py @@ -0,0 +1,46 @@ +from sqlalchemy.dialects.mssql.base import MSDialect +from sqlalchemy import types as sqltypes + + +class MSDialect_pymssql(MSDialect): + supports_sane_rowcount = False + max_identifier_length = 30 + driver = 'pymssql' + + @classmethod + def import_dbapi(cls): + import pymssql as module + # pymmsql doesn't have a Binary method. we use string + # TODO: monkeypatching here is less than ideal + module.Binary = lambda st: str(st) + return module + + def __init__(self, **params): + super(MSSQLDialect_pymssql, self).__init__(**params) + self.use_scope_identity = True + + # pymssql understands only ascii + if self.convert_unicode: + util.warn("pymssql does not support unicode") + self.encoding = params.get('encoding', 'ascii') + + + def create_connect_args(self, url): + if hasattr(self, 'query_timeout'): + # ick, globals ? we might want to move this.... + self.dbapi._mssql.set_query_timeout(self.query_timeout) + + keys = url.query + if keys.get('port'): + # pymssql expects port as host:port, not a separate arg + keys['host'] = ''.join([keys.get('host', ''), ':', str(keys['port'])]) + del keys['port'] + return [[], keys] + + def is_disconnect(self, e): + return isinstance(e, self.dbapi.DatabaseError) and "Error 10054" in str(e) + + def do_begin(self, connection): + pass + +dialect = MSDialect_pymssql
\ No newline at end of file diff --git a/lib/sqlalchemy/dialects/mssql/pyodbc.py b/lib/sqlalchemy/dialects/mssql/pyodbc.py new file mode 100644 index 000000000..9a2a9e4e7 --- /dev/null +++ b/lib/sqlalchemy/dialects/mssql/pyodbc.py @@ -0,0 +1,79 @@ +from sqlalchemy.dialects.mssql.base import MSExecutionContext, MSDialect +from sqlalchemy.connectors.pyodbc import PyODBCConnector +from sqlalchemy import types as sqltypes +import re +import sys + +class MSExecutionContext_pyodbc(MSExecutionContext): + _embedded_scope_identity = False + + def pre_exec(self): + """where appropriate, issue "select scope_identity()" in the same statement. + + Background on why "scope_identity()" is preferable to "@@identity": + http://msdn.microsoft.com/en-us/library/ms190315.aspx + + Background on why we attempt to embed "scope_identity()" into the same + statement as the INSERT: + http://code.google.com/p/pyodbc/wiki/FAQs#How_do_I_retrieve_autogenerated/identity_values? + + """ + + super(MSExecutionContext_pyodbc, self).pre_exec() + + # don't embed the scope_identity select into an "INSERT .. DEFAULT VALUES" + if self._select_lastrowid and \ + self.dialect.use_scope_identity and \ + len(self.parameters[0]): + self._embedded_scope_identity = True + + self.statement += "; select scope_identity()" + + def post_exec(self): + if self._embedded_scope_identity: + # Fetch the last inserted id from the manipulated statement + # We may have to skip over a number of result sets with no data (due to triggers, etc.) + while True: + try: + # fetchall() ensures the cursor is consumed without closing it (FreeTDS particularly) + row = self.cursor.fetchall()[0] + break + except self.dialect.dbapi.Error, e: + # no way around this - nextset() consumes the previous set + # so we need to just keep flipping + self.cursor.nextset() + + self._lastrowid = int(row[0]) + else: + super(MSExecutionContext_pyodbc, self).post_exec() + + +class MSDialect_pyodbc(PyODBCConnector, MSDialect): + supports_sane_rowcount = True + supports_sane_multi_rowcount = False + + execution_ctx_cls = MSExecutionContext_pyodbc + + pyodbc_driver_name = 'SQL Server' + + def __init__(self, description_encoding='latin-1', **params): + super(MSDialect_pyodbc, self).__init__(**params) + self.description_encoding = description_encoding + self.use_scope_identity = self.dbapi and hasattr(self.dbapi.Cursor, 'nextset') + + def initialize(self, connection): + super(MSDialect_pyodbc, self).initialize(connection) + pyodbc = self.dbapi + + dbapi_con = connection.connection + + self._free_tds = re.match(r".*libtdsodbc.*\.so", dbapi_con.getinfo(pyodbc.SQL_DRIVER_NAME)) + + # the "Py2K only" part here is theoretical. + # have not tried pyodbc + python3.1 yet. + # Py2K + self.supports_unicode_statements = not self._free_tds + self.supports_unicode_binds = not self._free_tds + # end Py2K + +dialect = MSDialect_pyodbc diff --git a/lib/sqlalchemy/dialects/mysql/__init__.py b/lib/sqlalchemy/dialects/mysql/__init__.py new file mode 100644 index 000000000..4106a299b --- /dev/null +++ b/lib/sqlalchemy/dialects/mysql/__init__.py @@ -0,0 +1,4 @@ +from sqlalchemy.dialects.mysql import base, mysqldb, pyodbc, zxjdbc + +# default dialect +base.dialect = mysqldb.dialect
\ No newline at end of file
diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py
new file mode 100644
index 000000000..1c5c251e5
--- /dev/null
+++ b/lib/sqlalchemy/dialects/mysql/base.py
@@ -0,0 +1,2545 @@
+# -*- fill-column: 78 -*-
+# mysql.py
+# Copyright (C) 2005, 2006, 2007, 2008, 2009 Michael Bayer mike_mp@zzzcomputing.com
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+
+"""Support for the MySQL database.
+
+Overview
+--------
+
+For normal SQLAlchemy usage, importing this module is unnecessary.  It will be
+loaded on-demand when a MySQL connection is needed.  The generic column types
+like :class:`~sqlalchemy.String` and :class:`~sqlalchemy.Integer` will
+automatically be adapted to the optimal matching MySQL column type.
+
+But if you would like to use one of the MySQL-specific or enhanced column
+types when creating tables with your :class:`~sqlalchemy.Table` definitions,
+then you will need to import them from this module::
+
+  from sqlalchemy.dialects.mysql import base as mysql
+
+  Table('mytable', metadata,
+        Column('id', Integer, primary_key=True),
+        Column('ittybittyblob', mysql.TINYBLOB),
+        Column('biggy', mysql.BIGINT(unsigned=True)))
+
+All standard MySQL column types are supported.  The OpenGIS types are
+available for use via table reflection but have no special support or mapping
+to Python classes.  If you're using these types and have opinions about how
+OpenGIS can be smartly integrated into SQLAlchemy please join the mailing
+list!
+
+Supported Versions and Features
+-------------------------------
+
+SQLAlchemy supports 6 major MySQL versions: 3.23, 4.0, 4.1, 5.0, 5.1 and 6.0,
+with capabilities increasing with more modern servers.
+
+Versions 4.1 and higher support the basic SQL functionality that SQLAlchemy
+uses in the ORM and SQL expressions.  These versions pass the applicable tests
+in the suite 100%.  No heroic measures are taken to work around major missing
+SQL features -- if your server version does not support sub-selects, for
+example, they won't work in SQLAlchemy either.
+
+Currently, the only DB-API driver supported is `MySQL-Python` (also referred to
+as `MySQLdb`).  Either 1.2.1 or 1.2.2 is recommended.  The alpha, beta and
+gamma releases of 1.2.1 and 1.2.2 should be avoided.  Support for Jython and
+IronPython is planned.
+
+===================================== ===============
+Feature                               Minimum Version
+===================================== ===============
+sqlalchemy.orm                        4.1.1
+Table Reflection                      3.23.x
+DDL Generation                        4.1.1
+utf8/Full Unicode Connections         4.1.1
+Transactions                          3.23.15
+Two-Phase Transactions                5.0.3
+Nested Transactions                   5.0.3
+===================================== ===============
+
+See the official MySQL documentation for detailed information about features
+supported in any given server release.
+
+Storage Engines
+---------------
+
+Most MySQL server installations have a default table type of ``MyISAM``, a
+non-transactional table type.  During a transaction, non-transactional storage
+engines do not participate and continue to store table changes in autocommit
+mode.  For fully atomic transactions, all participating tables must use a
+transactional engine such as ``InnoDB``, ``Falcon``, ``SolidDB``, ``PBXT``, etc.
+
+Storage engines can be selected when creating tables in SQLAlchemy by supplying
+a ``mysql_engine='whatever'`` to the ``Table`` constructor.
+Any MySQL table creation option can be specified in this syntax::
+
+  Table('mytable', metadata,
+        Column('data', String(32)),
+        mysql_engine='InnoDB',
+        mysql_charset='utf8'
+       )
+
+Keys
+----
+
+Not all MySQL storage engines support foreign keys.  For ``MyISAM`` and
+similar engines, the information loaded by table reflection will not include
+foreign keys.  For these tables, you may supply a
+:class:`~sqlalchemy.ForeignKeyConstraint` at reflection time::
+
+  Table('mytable', metadata,
+        ForeignKeyConstraint(['other_id'], ['othertable.other_id']),
+        autoload=True
+       )
+
+When creating tables, SQLAlchemy will automatically set ``AUTO_INCREMENT`` on
+an integer primary key column::
+
+  >>> t = Table('mytable', metadata,
+  ...           Column('id', Integer, primary_key=True)
+  ...          )
+  >>> t.create()
+  CREATE TABLE mytable (
+          id INTEGER NOT NULL AUTO_INCREMENT,
+          PRIMARY KEY (id)
+  )
+
+You can disable this behavior by supplying ``autoincrement=False`` to the
+:class:`~sqlalchemy.Column`.  This flag can also be used to enable
+auto-increment on a secondary column in a multi-column key for some storage
+engines::
+
+  Table('mytable', metadata,
+        Column('gid', Integer, primary_key=True, autoincrement=False),
+        Column('id', Integer, primary_key=True)
+       )
+
+SQL Mode
+--------
+
+MySQL SQL modes are supported.  Modes that enable ``ANSI_QUOTES`` (such as
+``ANSI``) require an engine option to modify SQLAlchemy's quoting style.
+When using an ANSI-quoting mode, supply ``use_ansiquotes=True`` when
+creating your ``Engine``::
+
+  create_engine('mysql://localhost/test', use_ansiquotes=True)
+
+This is an engine-wide option and is not toggleable on a per-connection basis.
+SQLAlchemy does not presume to ``SET sql_mode`` for you with this option.  For
+the best performance, set the quoting style server-wide in ``my.cnf`` or by
+supplying ``--sql-mode`` to ``mysqld``.  You can also use a
+:class:`sqlalchemy.pool.Pool` listener hook to issue a ``SET SESSION
+sql_mode='...'`` on connect to configure each connection.
+
+If you do not specify ``use_ansiquotes``, the regular MySQL quoting style is
+used by default.
+
+If you do issue a ``SET sql_mode`` through SQLAlchemy, the dialect must be
+updated if the quoting style is changed.  Again, this change will affect all
+connections::
+
+  connection.execute('SET sql_mode="ansi"')
+  connection.dialect.use_ansiquotes = True
+
+MySQL SQL Extensions
+--------------------
+
+Many of the MySQL SQL extensions are handled through SQLAlchemy's generic
+function and operator support::
+
+  table.select(table.c.password==func.md5('plaintext'))
+  table.select(table.c.username.op('regexp')('^[a-d]'))
+
+And of course any valid MySQL statement can be executed as a string as well.
+
+Some limited direct support for MySQL extensions to SQL is currently
+available.
+
+  * SELECT pragma::
+
+      select(..., prefixes=['HIGH_PRIORITY', 'SQL_SMALL_RESULT'])
+
+  * UPDATE with LIMIT::
+
+      update(..., mysql_limit=10)
+
+Troubleshooting
+---------------
+
+If you have problems that seem server related, first check that you are
+using the most recent stable MySQL-Python package available.  The Database
+Notes page on the wiki at http://www.sqlalchemy.org is a good resource for
+timely information affecting MySQL in SQLAlchemy.
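+
+As an example of the connect-time hook mentioned in the SQL Mode section
+above (a sketch only; the listener class name here is made up), a
+:class:`~sqlalchemy.interfaces.PoolListener` subclass can issue ``SET SESSION
+sql_mode`` on each new connection::
+
+  from sqlalchemy import create_engine
+  from sqlalchemy.interfaces import PoolListener
+
+  class SetSQLModeListener(PoolListener):
+      def connect(self, dbapi_con, con_record):
+          # runs once for each new DBAPI connection as the pool creates it
+          cursor = dbapi_con.cursor()
+          cursor.execute("SET SESSION sql_mode='TRADITIONAL'")
+          cursor.close()
+
+  create_engine('mysql://localhost/test',
+                listeners=[SetSQLModeListener()])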
+ +""" + +import datetime, inspect, re, sys + +from sqlalchemy import schema as sa_schema +from sqlalchemy import exc, log, sql, util +from sqlalchemy.sql import operators as sql_operators +from sqlalchemy.sql import functions as sql_functions +from sqlalchemy.sql import compiler +from array import array as _array + +from sqlalchemy.engine import reflection +from sqlalchemy.engine import base as engine_base, default +from sqlalchemy import types as sqltypes + +from sqlalchemy.types import DATE, DATETIME, BOOLEAN, TIME + +RESERVED_WORDS = set( + ['accessible', 'add', 'all', 'alter', 'analyze','and', 'as', 'asc', + 'asensitive', 'before', 'between', 'bigint', 'binary', 'blob', 'both', + 'by', 'call', 'cascade', 'case', 'change', 'char', 'character', 'check', + 'collate', 'column', 'condition', 'constraint', 'continue', 'convert', + 'create', 'cross', 'current_date', 'current_time', 'current_timestamp', + 'current_user', 'cursor', 'database', 'databases', 'day_hour', + 'day_microsecond', 'day_minute', 'day_second', 'dec', 'decimal', + 'declare', 'default', 'delayed', 'delete', 'desc', 'describe', + 'deterministic', 'distinct', 'distinctrow', 'div', 'double', 'drop', + 'dual', 'each', 'else', 'elseif', 'enclosed', 'escaped', 'exists', + 'exit', 'explain', 'false', 'fetch', 'float', 'float4', 'float8', + 'for', 'force', 'foreign', 'from', 'fulltext', 'grant', 'group', 'having', + 'high_priority', 'hour_microsecond', 'hour_minute', 'hour_second', 'if', + 'ignore', 'in', 'index', 'infile', 'inner', 'inout', 'insensitive', + 'insert', 'int', 'int1', 'int2', 'int3', 'int4', 'int8', 'integer', + 'interval', 'into', 'is', 'iterate', 'join', 'key', 'keys', 'kill', + 'leading', 'leave', 'left', 'like', 'limit', 'linear', 'lines', 'load', + 'localtime', 'localtimestamp', 'lock', 'long', 'longblob', 'longtext', + 'loop', 'low_priority', 'master_ssl_verify_server_cert', 'match', + 'mediumblob', 'mediumint', 'mediumtext', 'middleint', + 'minute_microsecond', 'minute_second', 'mod', 'modifies', 'natural', + 'not', 'no_write_to_binlog', 'null', 'numeric', 'on', 'optimize', + 'option', 'optionally', 'or', 'order', 'out', 'outer', 'outfile', + 'precision', 'primary', 'procedure', 'purge', 'range', 'read', 'reads', + 'read_only', 'read_write', 'real', 'references', 'regexp', 'release', + 'rename', 'repeat', 'replace', 'require', 'restrict', 'return', + 'revoke', 'right', 'rlike', 'schema', 'schemas', 'second_microsecond', + 'select', 'sensitive', 'separator', 'set', 'show', 'smallint', 'spatial', + 'specific', 'sql', 'sqlexception', 'sqlstate', 'sqlwarning', + 'sql_big_result', 'sql_calc_found_rows', 'sql_small_result', 'ssl', + 'starting', 'straight_join', 'table', 'terminated', 'then', 'tinyblob', + 'tinyint', 'tinytext', 'to', 'trailing', 'trigger', 'true', 'undo', + 'union', 'unique', 'unlock', 'unsigned', 'update', 'usage', 'use', + 'using', 'utc_date', 'utc_time', 'utc_timestamp', 'values', 'varbinary', + 'varchar', 'varcharacter', 'varying', 'when', 'where', 'while', 'with', + 'write', 'x509', 'xor', 'year_month', 'zerofill', # 5.0 + 'columns', 'fields', 'privileges', 'soname', 'tables', # 4.1 + 'accessible', 'linear', 'master_ssl_verify_server_cert', 'range', + 'read_only', 'read_write', # 5.1 + ]) + +AUTOCOMMIT_RE = re.compile( + r'\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER|LOAD +DATA|REPLACE)', + re.I | re.UNICODE) +SET_RE = re.compile( + r'\s*SET\s+(?:(?:GLOBAL|SESSION)\s+)?\w', + re.I | re.UNICODE) + + +class _NumericType(object): + """Base for MySQL numeric types.""" + + def __init__(self, **kw): 
+ self.unsigned = kw.pop('unsigned', False) + self.zerofill = kw.pop('zerofill', False) + super(_NumericType, self).__init__(**kw) + +class _FloatType(_NumericType, sqltypes.Float): + def __init__(self, precision=None, scale=None, asdecimal=True, **kw): + if isinstance(self, (REAL, DOUBLE)) and \ + ( + (precision is None and scale is not None) or + (precision is not None and scale is None) + ): + raise exc.ArgumentError( + "You must specify both precision and scale or omit " + "both altogether.") + + super(_FloatType, self).__init__(precision=precision, asdecimal=asdecimal, **kw) + self.scale = scale + +class _IntegerType(_NumericType, sqltypes.Integer): + def __init__(self, display_width=None, **kw): + self.display_width = display_width + super(_IntegerType, self).__init__(**kw) + +class _StringType(sqltypes.String): + """Base for MySQL string types.""" + + def __init__(self, charset=None, collation=None, + ascii=False, unicode=False, binary=False, + national=False, **kw): + self.charset = charset + # allow collate= or collation= + self.collation = kw.pop('collate', collation) + self.ascii = ascii + self.unicode = unicode + self.binary = binary + self.national = national + super(_StringType, self).__init__(**kw) + + def __repr__(self): + attributes = inspect.getargspec(self.__init__)[0][1:] + attributes.extend(inspect.getargspec(_StringType.__init__)[0][1:]) + + params = {} + for attr in attributes: + val = getattr(self, attr) + if val is not None and val is not False: + params[attr] = val + + return "%s(%s)" % (self.__class__.__name__, + ', '.join(['%s=%r' % (k, params[k]) for k in params])) + + +class _BinaryType(sqltypes.Binary): + """Base for MySQL binary types.""" + + def result_processor(self, dialect): + def process(value): + if value is None: + return None + else: + return util.buffer(value) + return process + +class NUMERIC(_NumericType, sqltypes.NUMERIC): + """MySQL NUMERIC type.""" + + __visit_name__ = 'NUMERIC' + + def __init__(self, precision=None, scale=None, asdecimal=True, **kw): + """Construct a NUMERIC. + + :param precision: Total digits in this number. If scale and precision + are both None, values are stored to limits allowed by the server. + + :param scale: The number of digits after the decimal point. + + :param unsigned: a boolean, optional. + + :param zerofill: Optional. If true, values will be stored as strings + left-padded with zeros. Note that this does not effect the values + returned by the underlying database API, which continue to be + numeric. + + """ + super(NUMERIC, self).__init__(precision=precision, scale=scale, asdecimal=asdecimal, **kw) + + +class DECIMAL(_NumericType, sqltypes.DECIMAL): + """MySQL DECIMAL type.""" + + __visit_name__ = 'DECIMAL' + + def __init__(self, precision=None, scale=None, asdecimal=True, **kw): + """Construct a DECIMAL. + + :param precision: Total digits in this number. If scale and precision + are both None, values are stored to limits allowed by the server. + + :param scale: The number of digits after the decimal point. + + :param unsigned: a boolean, optional. + + :param zerofill: Optional. If true, values will be stored as strings + left-padded with zeros. Note that this does not effect the values + returned by the underlying database API, which continue to be + numeric. 
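+
+        For example (an illustrative column definition)::
+
+            Column('amount', DECIMAL(precision=10, scale=2, unsigned=True))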
+ + """ + super(DECIMAL, self).__init__(precision=precision, scale=scale, asdecimal=asdecimal, **kw) + + +class DOUBLE(_FloatType): + """MySQL DOUBLE type.""" + + __visit_name__ = 'DOUBLE' + + def __init__(self, precision=None, scale=None, asdecimal=True, **kw): + """Construct a DOUBLE. + + :param precision: Total digits in this number. If scale and precision + are both None, values are stored to limits allowed by the server. + + :param scale: The number of digits after the decimal point. + + :param unsigned: a boolean, optional. + + :param zerofill: Optional. If true, values will be stored as strings + left-padded with zeros. Note that this does not effect the values + returned by the underlying database API, which continue to be + numeric. + + """ + super(DOUBLE, self).__init__(precision=precision, scale=scale, asdecimal=asdecimal, **kw) + +class REAL(_FloatType): + """MySQL REAL type.""" + + __visit_name__ = 'REAL' + + def __init__(self, precision=None, scale=None, asdecimal=True, **kw): + """Construct a REAL. + + :param precision: Total digits in this number. If scale and precision + are both None, values are stored to limits allowed by the server. + + :param scale: The number of digits after the decimal point. + + :param unsigned: a boolean, optional. + + :param zerofill: Optional. If true, values will be stored as strings + left-padded with zeros. Note that this does not effect the values + returned by the underlying database API, which continue to be + numeric. + + """ + super(REAL, self).__init__(precision=precision, scale=scale, asdecimal=asdecimal, **kw) + +class FLOAT(_FloatType, sqltypes.FLOAT): + """MySQL FLOAT type.""" + + __visit_name__ = 'FLOAT' + + def __init__(self, precision=None, scale=None, asdecimal=False, **kw): + """Construct a FLOAT. + + :param precision: Total digits in this number. If scale and precision + are both None, values are stored to limits allowed by the server. + + :param scale: The number of digits after the decimal point. + + :param unsigned: a boolean, optional. + + :param zerofill: Optional. If true, values will be stored as strings + left-padded with zeros. Note that this does not effect the values + returned by the underlying database API, which continue to be + numeric. + + """ + super(FLOAT, self).__init__(precision=precision, scale=scale, asdecimal=asdecimal, **kw) + + def bind_processor(self, dialect): + return None + +class INTEGER(_IntegerType, sqltypes.INTEGER): + """MySQL INTEGER type.""" + + __visit_name__ = 'INTEGER' + + def __init__(self, display_width=None, **kw): + """Construct an INTEGER. + + :param display_width: Optional, maximum display width for this number. + + :param unsigned: a boolean, optional. + + :param zerofill: Optional. If true, values will be stored as strings + left-padded with zeros. Note that this does not effect the values + returned by the underlying database API, which continue to be + numeric. + + """ + super(INTEGER, self).__init__(display_width=display_width, **kw) + +class BIGINT(_IntegerType, sqltypes.BIGINT): + """MySQL BIGINTEGER type.""" + + __visit_name__ = 'BIGINT' + + def __init__(self, display_width=None, **kw): + """Construct a BIGINTEGER. + + :param display_width: Optional, maximum display width for this number. + + :param unsigned: a boolean, optional. + + :param zerofill: Optional. If true, values will be stored as strings + left-padded with zeros. Note that this does not effect the values + returned by the underlying database API, which continue to be + numeric. 
+ + """ + super(BIGINT, self).__init__(display_width=display_width, **kw) + +class MEDIUMINT(_IntegerType): + """MySQL MEDIUMINTEGER type.""" + + __visit_name__ = 'MEDIUMINT' + + def __init__(self, display_width=None, **kw): + """Construct a MEDIUMINTEGER + + :param display_width: Optional, maximum display width for this number. + + :param unsigned: a boolean, optional. + + :param zerofill: Optional. If true, values will be stored as strings + left-padded with zeros. Note that this does not effect the values + returned by the underlying database API, which continue to be + numeric. + + """ + super(MEDIUMINT, self).__init__(display_width=display_width, **kw) + +class TINYINT(_IntegerType): + """MySQL TINYINT type.""" + + __visit_name__ = 'TINYINT' + + def __init__(self, display_width=None, **kw): + """Construct a TINYINT. + + Note: following the usual MySQL conventions, TINYINT(1) columns + reflected during Table(..., autoload=True) are treated as + Boolean columns. + + :param display_width: Optional, maximum display width for this number. + + :param unsigned: a boolean, optional. + + :param zerofill: Optional. If true, values will be stored as strings + left-padded with zeros. Note that this does not effect the values + returned by the underlying database API, which continue to be + numeric. + + """ + super(TINYINT, self).__init__(display_width=display_width, **kw) + +class SMALLINT(_IntegerType, sqltypes.SMALLINT): + """MySQL SMALLINTEGER type.""" + + __visit_name__ = 'SMALLINT' + + def __init__(self, display_width=None, **kw): + """Construct a SMALLINTEGER. + + :param display_width: Optional, maximum display width for this number. + + :param unsigned: a boolean, optional. + + :param zerofill: Optional. If true, values will be stored as strings + left-padded with zeros. Note that this does not effect the values + returned by the underlying database API, which continue to be + numeric. + + """ + super(SMALLINT, self).__init__(display_width=display_width, **kw) + +class BIT(sqltypes.TypeEngine): + """MySQL BIT type. + + This type is for MySQL 5.0.3 or greater for MyISAM, and 5.0.5 or greater for + MyISAM, MEMORY, InnoDB and BDB. For older versions, use a MSTinyInteger() + type. + + """ + + __visit_name__ = 'BIT' + + def __init__(self, length=None): + """Construct a BIT. + + :param length: Optional, number of bits. + + """ + self.length = length + + def result_processor(self, dialect): + """Convert a MySQL's 64 bit, variable length binary string to a long.""" + def process(value): + if value is not None: + v = 0L + for i in map(ord, value): + v = v << 8 | i + value = v + return value + return process + +class _MSTime(sqltypes.Time): + """MySQL TIME type.""" + + __visit_name__ = 'TIME' + + def result_processor(self, dialect): + def process(value): + # convert from a timedelta value + if value is not None: + return datetime.time(value.seconds/60/60, value.seconds/60%60, value.seconds - (value.seconds/60*60)) + else: + return None + return process + +class TIMESTAMP(sqltypes.TIMESTAMP): + """MySQL TIMESTAMP type.""" + __visit_name__ = 'TIMESTAMP' + +class YEAR(sqltypes.TypeEngine): + """MySQL YEAR type, for single byte storage of years 1901-2155.""" + + __visit_name__ = 'YEAR' + + def __init__(self, display_width=None): + self.display_width = display_width + +class TEXT(_StringType, sqltypes.TEXT): + """MySQL TEXT type, for text up to 2^16 characters.""" + + __visit_name__ = 'TEXT' + + def __init__(self, length=None, **kw): + """Construct a TEXT. 
+ + :param length: Optional, if provided the server may optimize storage + by substituting the smallest TEXT type sufficient to store + ``length`` characters. + + :param charset: Optional, a column-level character set for this string + value. Takes precedence to 'ascii' or 'unicode' short-hand. + + :param collation: Optional, a column-level collation for this string + value. Takes precedence to 'binary' short-hand. + + :param ascii: Defaults to False: short-hand for the ``latin1`` + character set, generates ASCII in schema. + + :param unicode: Defaults to False: short-hand for the ``ucs2`` + character set, generates UNICODE in schema. + + :param national: Optional. If true, use the server's configured + national character set. + + :param binary: Defaults to False: short-hand, pick the binary + collation type that matches the column's character set. Generates + BINARY in schema. This does not affect the type of data stored, + only the collation of character data. + + """ + super(TEXT, self).__init__(length=length, **kw) + +class TINYTEXT(_StringType): + """MySQL TINYTEXT type, for text up to 2^8 characters.""" + + __visit_name__ = 'TINYTEXT' + + def __init__(self, **kwargs): + """Construct a TINYTEXT. + + :param charset: Optional, a column-level character set for this string + value. Takes precedence to 'ascii' or 'unicode' short-hand. + + :param collation: Optional, a column-level collation for this string + value. Takes precedence to 'binary' short-hand. + + :param ascii: Defaults to False: short-hand for the ``latin1`` + character set, generates ASCII in schema. + + :param unicode: Defaults to False: short-hand for the ``ucs2`` + character set, generates UNICODE in schema. + + :param national: Optional. If true, use the server's configured + national character set. + + :param binary: Defaults to False: short-hand, pick the binary + collation type that matches the column's character set. Generates + BINARY in schema. This does not affect the type of data stored, + only the collation of character data. + + """ + super(TINYTEXT, self).__init__(**kwargs) + +class MEDIUMTEXT(_StringType): + """MySQL MEDIUMTEXT type, for text up to 2^24 characters.""" + + __visit_name__ = 'MEDIUMTEXT' + + def __init__(self, **kwargs): + """Construct a MEDIUMTEXT. + + :param charset: Optional, a column-level character set for this string + value. Takes precedence to 'ascii' or 'unicode' short-hand. + + :param collation: Optional, a column-level collation for this string + value. Takes precedence to 'binary' short-hand. + + :param ascii: Defaults to False: short-hand for the ``latin1`` + character set, generates ASCII in schema. + + :param unicode: Defaults to False: short-hand for the ``ucs2`` + character set, generates UNICODE in schema. + + :param national: Optional. If true, use the server's configured + national character set. + + :param binary: Defaults to False: short-hand, pick the binary + collation type that matches the column's character set. Generates + BINARY in schema. This does not affect the type of data stored, + only the collation of character data. + + """ + super(MEDIUMTEXT, self).__init__(**kwargs) + +class LONGTEXT(_StringType): + """MySQL LONGTEXT type, for text up to 2^32 characters.""" + + __visit_name__ = 'LONGTEXT' + + def __init__(self, **kwargs): + """Construct a LONGTEXT. + + :param charset: Optional, a column-level character set for this string + value. Takes precedence to 'ascii' or 'unicode' short-hand. 
+ + :param collation: Optional, a column-level collation for this string + value. Takes precedence to 'binary' short-hand. + + :param ascii: Defaults to False: short-hand for the ``latin1`` + character set, generates ASCII in schema. + + :param unicode: Defaults to False: short-hand for the ``ucs2`` + character set, generates UNICODE in schema. + + :param national: Optional. If true, use the server's configured + national character set. + + :param binary: Defaults to False: short-hand, pick the binary + collation type that matches the column's character set. Generates + BINARY in schema. This does not affect the type of data stored, + only the collation of character data. + + """ + super(LONGTEXT, self).__init__(**kwargs) + + +class VARCHAR(_StringType, sqltypes.VARCHAR): + """MySQL VARCHAR type, for variable-length character data.""" + + __visit_name__ = 'VARCHAR' + + def __init__(self, length=None, **kwargs): + """Construct a VARCHAR. + + :param charset: Optional, a column-level character set for this string + value. Takes precedence to 'ascii' or 'unicode' short-hand. + + :param collation: Optional, a column-level collation for this string + value. Takes precedence to 'binary' short-hand. + + :param ascii: Defaults to False: short-hand for the ``latin1`` + character set, generates ASCII in schema. + + :param unicode: Defaults to False: short-hand for the ``ucs2`` + character set, generates UNICODE in schema. + + :param national: Optional. If true, use the server's configured + national character set. + + :param binary: Defaults to False: short-hand, pick the binary + collation type that matches the column's character set. Generates + BINARY in schema. This does not affect the type of data stored, + only the collation of character data. + + """ + super(VARCHAR, self).__init__(length=length, **kwargs) + +class CHAR(_StringType, sqltypes.CHAR): + """MySQL CHAR type, for fixed-length character data.""" + + __visit_name__ = 'CHAR' + + def __init__(self, length, **kwargs): + """Construct a CHAR. + + :param length: Maximum data length, in characters. + + :param binary: Optional, use the default binary collation for the + national character set. This does not affect the type of data + stored, use a BINARY type for binary data. + + :param collation: Optional, request a particular collation. Must be + compatible with the national character set. + + """ + super(CHAR, self).__init__(length=length, **kwargs) + +class NVARCHAR(_StringType, sqltypes.NVARCHAR): + """MySQL NVARCHAR type. + + For variable-length character data in the server's configured national + character set. + """ + + __visit_name__ = 'NVARCHAR' + + def __init__(self, length=None, **kwargs): + """Construct an NVARCHAR. + + :param length: Maximum data length, in characters. + + :param binary: Optional, use the default binary collation for the + national character set. This does not affect the type of data + stored, use a BINARY type for binary data. + + :param collation: Optional, request a particular collation. Must be + compatible with the national character set. + + """ + kwargs['national'] = True + super(NVARCHAR, self).__init__(length=length, **kwargs) + + +class NCHAR(_StringType, sqltypes.NCHAR): + """MySQL NCHAR type. + + For fixed-length character data in the server's configured national + character set. + """ + + __visit_name__ = 'NCHAR' + + def __init__(self, length=None, **kwargs): + """Construct an NCHAR. Arguments are: + + :param length: Maximum data length, in characters. 
+ + :param binary: Optional, use the default binary collation for the + national character set. This does not affect the type of data + stored, use a BINARY type for binary data. + + :param collation: Optional, request a particular collation. Must be + compatible with the national character set. + + """ + kwargs['national'] = True + super(NCHAR, self).__init__(length=length, **kwargs) + + + +class VARBINARY(_BinaryType): + """MySQL VARBINARY type, for variable length binary data.""" + + __visit_name__ = 'VARBINARY' + + def __init__(self, length=None, **kw): + """Construct a VARBINARY. Arguments are: + + :param length: Maximum data length, in characters. + + """ + super(VARBINARY, self).__init__(length=length, **kw) + +class BINARY(_BinaryType): + """MySQL BINARY type, for fixed length binary data""" + + __visit_name__ = 'BINARY' + + def __init__(self, length=None, **kw): + """Construct a BINARY. + + This is a fixed length type, and short values will be right-padded + with a server-version-specific pad value. + + :param length: Maximum data length, in bytes. If length is not + specified, this will generate a BLOB. This usage is deprecated. + + """ + super(BINARY, self).__init__(length=length, **kw) + +class BLOB(_BinaryType, sqltypes.BLOB): + """MySQL BLOB type, for binary data up to 2^16 bytes""" + + __visit_name__ = 'BLOB' + + def __init__(self, length=None, **kw): + """Construct a BLOB. Arguments are: + + :param length: Optional, if provided the server may optimize storage + by substituting the smallest TEXT type sufficient to store + ``length`` characters. + + """ + super(BLOB, self).__init__(length=length, **kw) + + +class TINYBLOB(_BinaryType): + """MySQL TINYBLOB type, for binary data up to 2^8 bytes.""" + + __visit_name__ = 'TINYBLOB' + +class MEDIUMBLOB(_BinaryType): + """MySQL MEDIUMBLOB type, for binary data up to 2^24 bytes.""" + + __visit_name__ = 'MEDIUMBLOB' + +class LONGBLOB(_BinaryType): + """MySQL LONGBLOB type, for binary data up to 2^32 bytes.""" + + __visit_name__ = 'LONGBLOB' + +class ENUM(_StringType): + """MySQL ENUM type.""" + + __visit_name__ = 'ENUM' + + def __init__(self, *enums, **kw): + """Construct an ENUM. + + Example: + + Column('myenum', MSEnum("foo", "bar", "baz")) + + Arguments are: + + :param enums: The range of valid values for this ENUM. Values will be + quoted when generating the schema according to the quoting flag (see + below). + + :param strict: Defaults to False: ensure that a given value is in this + ENUM's range of permissible values when inserting or updating rows. + Note that MySQL will not raise a fatal error if you attempt to store + an out of range value- an alternate value will be stored instead. + (See MySQL ENUM documentation.) + + :param charset: Optional, a column-level character set for this string + value. Takes precedence to 'ascii' or 'unicode' short-hand. + + :param collation: Optional, a column-level collation for this string + value. Takes precedence to 'binary' short-hand. + + :param ascii: Defaults to False: short-hand for the ``latin1`` + character set, generates ASCII in schema. + + :param unicode: Defaults to False: short-hand for the ``ucs2`` + character set, generates UNICODE in schema. + + :param binary: Defaults to False: short-hand, pick the binary + collation type that matches the column's character set. Generates + BINARY in schema. This does not affect the type of data stored, + only the collation of character data. + + :param quoting: Defaults to 'auto': automatically determine enum value + quoting. 
If all enum values are surrounded by the same quoting + character, then use 'quoted' mode. Otherwise, use 'unquoted' mode. + + 'quoted': values in enums are already quoted, they will be used + directly when generating the schema. + + 'unquoted': values in enums are not quoted, they will be escaped and + surrounded by single quotes when generating the schema. + + Previous versions of this type always required manually quoted + values to be supplied; future versions will always quote the string + literals for you. This is a transitional option. + + """ + self.quoting = kw.pop('quoting', 'auto') + + if self.quoting == 'auto': + # What quoting character are we using? + q = None + for e in enums: + if len(e) == 0: + self.quoting = 'unquoted' + break + elif q is None: + q = e[0] + + if e[0] != q or e[-1] != q: + self.quoting = 'unquoted' + break + else: + self.quoting = 'quoted' + + if self.quoting == 'quoted': + util.warn_pending_deprecation( + 'Manually quoting ENUM value literals is deprecated. Supply ' + 'unquoted values and use the quoting= option in cases of ' + 'ambiguity.') + strip_enums = [] + for a in enums: + if a[0:1] == '"' or a[0:1] == "'": + # strip enclosing quotes and unquote interior + a = a[1:-1].replace(a[0] * 2, a[0]) + strip_enums.append(a) + self.enums = strip_enums + else: + self.enums = list(enums) + + self.strict = kw.pop('strict', False) + length = max([len(v) for v in self.enums] + [0]) + super(ENUM, self).__init__(length=length, **kw) + + def bind_processor(self, dialect): + super_convert = super(ENUM, self).bind_processor(dialect) + def process(value): + if self.strict and value is not None and value not in self.enums: + raise exc.InvalidRequestError('"%s" not a valid value for ' + 'this enum' % value) + if super_convert: + return super_convert(value) + else: + return value + return process + +class SET(_StringType): + """MySQL SET type.""" + + __visit_name__ = 'SET' + + def __init__(self, *values, **kw): + """Construct a SET. + + Example:: + + Column('myset', MSSet("'foo'", "'bar'", "'baz'")) + + Arguments are: + + :param values: The range of valid values for this SET. Values will be + used exactly as they appear when generating schemas. Strings must + be quoted, as in the example above. Single-quotes are suggested for + ANSI compatibility and are required for portability to servers with + ANSI_QUOTES enabled. + + :param charset: Optional, a column-level character set for this string + value. Takes precedence to 'ascii' or 'unicode' short-hand. + + :param collation: Optional, a column-level collation for this string + value. Takes precedence to 'binary' short-hand. + + :param ascii: Defaults to False: short-hand for the ``latin1`` + character set, generates ASCII in schema. + + :param unicode: Defaults to False: short-hand for the ``ucs2`` + character set, generates UNICODE in schema. + + :param binary: Defaults to False: short-hand, pick the binary + collation type that matches the column's character set. Generates + BINARY in schema. This does not affect the type of data stored, + only the collation of character data. 
+ + """ + self._ddl_values = values + + strip_values = [] + for a in values: + if a[0:1] == '"' or a[0:1] == "'": + # strip enclosing quotes and unquote interior + a = a[1:-1].replace(a[0] * 2, a[0]) + strip_values.append(a) + + self.values = strip_values + length = max([len(v) for v in strip_values] + [0]) + super(SET, self).__init__(length=length, **kw) + + def result_processor(self, dialect): + def process(value): + # The good news: + # No ',' quoting issues- commas aren't allowed in SET values + # The bad news: + # Plenty of driver inconsistencies here. + if isinstance(value, util.set_types): + # ..some versions convert '' to an empty set + if not value: + value.add('') + # ..some return sets.Set, even for pythons that have __builtin__.set + if not isinstance(value, set): + value = set(value) + return value + # ...and some versions return strings + if value is not None: + return set(value.split(',')) + else: + return value + return process + + def bind_processor(self, dialect): + super_convert = super(SET, self).bind_processor(dialect) + def process(value): + if value is None or isinstance(value, (int, long, basestring)): + pass + else: + if None in value: + value = set(value) + value.remove(None) + value.add('') + value = ','.join(value) + if super_convert: + return super_convert(value) + else: + return value + return process + +class _MSBoolean(sqltypes.Boolean): + """MySQL BOOLEAN type.""" + + __visit_name__ = 'BOOLEAN' + + def result_processor(self, dialect): + def process(value): + if value is None: + return None + return value and True or False + return process + + def bind_processor(self, dialect): + def process(value): + if value is True: + return 1 + elif value is False: + return 0 + elif value is None: + return None + else: + return value and True or False + return process + +# old names +MSBoolean = _MSBoolean +MSTime = _MSTime +MSSet = SET +MSEnum = ENUM +MSLongBlob = LONGBLOB +MSMediumBlob = MEDIUMBLOB +MSTinyBlob = TINYBLOB +MSBlob = BLOB +MSBinary = BINARY +MSVarBinary = VARBINARY +MSNChar = NCHAR +MSNVarChar = NVARCHAR +MSChar = CHAR +MSString = VARCHAR +MSLongText = LONGTEXT +MSMediumText = MEDIUMTEXT +MSTinyText = TINYTEXT +MSText = TEXT +MSYear = YEAR +MSTimeStamp = TIMESTAMP +MSBit = BIT +MSSmallInteger = SMALLINT +MSTinyInteger = TINYINT +MSMediumInteger = MEDIUMINT +MSBigInteger = BIGINT +MSNumeric = NUMERIC +MSDecimal = DECIMAL +MSDouble = DOUBLE +MSReal = REAL +MSFloat = FLOAT +MSInteger = INTEGER + +colspecs = { + sqltypes.Numeric: NUMERIC, + sqltypes.Float: FLOAT, + sqltypes.Binary: _BinaryType, + sqltypes.Boolean: _MSBoolean, + sqltypes.Time: _MSTime, +} + +# Everything 3.23 through 5.1 excepting OpenGIS types. 
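+#
+# ischema_names maps each reflected type keyword (lower-cased) to the
+# dialect type class; _parse_column() below resolves reflected columns
+# through it, roughly (a sketch; 'varchar' and 32 are hypothetical):
+#
+#   col_type = ischema_names['varchar']   # -> VARCHAR
+#   type_instance = col_type(32)          # rendered as VARCHAR(32)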
+ischema_names = { + 'bigint': BIGINT, + 'binary': BINARY, + 'bit': BIT, + 'blob': BLOB, + 'boolean':BOOLEAN, + 'char': CHAR, + 'date': DATE, + 'datetime': DATETIME, + 'decimal': DECIMAL, + 'double': DOUBLE, + 'enum': ENUM, + 'fixed': DECIMAL, + 'float': FLOAT, + 'int': INTEGER, + 'integer': INTEGER, + 'longblob': LONGBLOB, + 'longtext': LONGTEXT, + 'mediumblob': MEDIUMBLOB, + 'mediumint': MEDIUMINT, + 'mediumtext': MEDIUMTEXT, + 'nchar': NCHAR, + 'nvarchar': NVARCHAR, + 'numeric': NUMERIC, + 'set': SET, + 'smallint': SMALLINT, + 'text': TEXT, + 'time': TIME, + 'timestamp': TIMESTAMP, + 'tinyblob': TINYBLOB, + 'tinyint': TINYINT, + 'tinytext': TINYTEXT, + 'varbinary': VARBINARY, + 'varchar': VARCHAR, + 'year': YEAR, +} + +class MySQLExecutionContext(default.DefaultExecutionContext): + def post_exec(self): + # TODO: i think this 'charset' in the info thing + # is out + + if (not self.isupdate and not self.should_autocommit and + self.statement and SET_RE.match(self.statement)): + # This misses if a user forces autocommit on text('SET NAMES'), + # which is probably a programming error anyhow. + self.connection.info.pop(('mysql', 'charset'), None) + + def should_autocommit_text(self, statement): + return AUTOCOMMIT_RE.match(statement) + +class MySQLCompiler(compiler.SQLCompiler): + + extract_map = compiler.SQLCompiler.extract_map.copy() + extract_map.update ({ + 'milliseconds': 'millisecond', + }) + + def visit_random_func(self, fn, **kw): + return "rand%s" % self.function_argspec(fn) + + def visit_utc_timestamp_func(self, fn, **kw): + return "UTC_TIMESTAMP" + + def visit_concat_op(self, binary, **kw): + return "concat(%s, %s)" % (self.process(binary.left), self.process(binary.right)) + + def visit_match_op(self, binary, **kw): + return "MATCH (%s) AGAINST (%s IN BOOLEAN MODE)" % (self.process(binary.left), self.process(binary.right)) + + def visit_typeclause(self, typeclause): + type_ = typeclause.type.dialect_impl(self.dialect) + if isinstance(type_, sqltypes.Integer): + if getattr(type_, 'unsigned', False): + return 'UNSIGNED INTEGER' + else: + return 'SIGNED INTEGER' + elif isinstance(type_, sqltypes.TIMESTAMP): + return 'DATETIME' + elif isinstance(type_, (sqltypes.DECIMAL, sqltypes.DateTime, sqltypes.Date, sqltypes.Time)): + return self.dialect.type_compiler.process(type_) + elif isinstance(type_, sqltypes.Text): + return 'CHAR' + elif (isinstance(type_, sqltypes.String) and not + isinstance(type_, (ENUM, SET))): + if getattr(type_, 'length'): + return 'CHAR(%s)' % type_.length + else: + return 'CHAR' + elif isinstance(type_, sqltypes.Binary): + return 'BINARY' + elif isinstance(type_, NUMERIC): + return self.dialect.type_compiler.process(type_).replace('NUMERIC', 'DECIMAL') + else: + return None + + def visit_cast(self, cast, **kwargs): + # No cast until 4, no decimals until 5. + type_ = self.process(cast.typeclause) + if type_ is None: + return self.process(cast.clause) + + return 'CAST(%s AS %s)' % (self.process(cast.clause), type_) + + def get_select_precolumns(self, select): + if isinstance(select._distinct, basestring): + return select._distinct.upper() + " " + elif select._distinct: + return "DISTINCT " + else: + return "" + + def visit_join(self, join, asfrom=False, **kwargs): + # 'JOIN ... ON ...' for inner joins isn't available until 4.0. + # Apparently < 3.23.17 requires theta joins for inner joins + # (but not outer). Not generating these currently, but + # support can be added, preferably after dialects are + # refactored to be version-sensitive. 
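+        #
+        # For reference, the expression assembled below renders as
+        # (names hypothetical):
+        #   users INNER JOIN addresses ON users.id = addresses.user_id
+        # or, for an outer join:
+        #   users LEFT OUTER JOIN addresses ON users.id = addresses.user_id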
+ return ''.join( + (self.process(join.left, asfrom=True), + (join.isouter and " LEFT OUTER JOIN " or " INNER JOIN "), + self.process(join.right, asfrom=True), + " ON ", + self.process(join.onclause))) + + def for_update_clause(self, select): + if select.for_update == 'read': + return ' LOCK IN SHARE MODE' + else: + return super(MySQLCompiler, self).for_update_clause(select) + + def limit_clause(self, select): + # MySQL supports: + # LIMIT <limit> + # LIMIT <offset>, <limit> + # and in server versions > 3.3: + # LIMIT <limit> OFFSET <offset> + # The latter is more readable for offsets but we're stuck with the + # former until we can refine dialects by server revision. + + limit, offset = select._limit, select._offset + + if (limit, offset) == (None, None): + return '' + elif offset is not None: + # As suggested by the MySQL docs, need to apply an + # artificial limit if one wasn't provided + if limit is None: + limit = 18446744073709551615 + return ' \n LIMIT %s, %s' % (offset, limit) + else: + # No offset provided, so just use the limit + return ' \n LIMIT %s' % (limit,) + + def visit_update(self, update_stmt): + self.stack.append({'from': set([update_stmt.table])}) + + self.isupdate = True + colparams = self._get_colparams(update_stmt) + + text = "UPDATE " + self.preparer.format_table(update_stmt.table) + \ + " SET " + ', '.join(["%s=%s" % (self.preparer.format_column(c[0]), c[1]) for c in colparams]) + + if update_stmt._whereclause: + text += " WHERE " + self.process(update_stmt._whereclause) + + limit = update_stmt.kwargs.get('mysql_limit', None) + if limit: + text += " LIMIT %s" % limit + + self.stack.pop(-1) + + return text + +# ug. "InnoDB needs indexes on foreign keys and referenced keys [...]. +# Starting with MySQL 4.1.2, these indexes are created automatically. +# In older versions, the indexes must be created explicitly or the +# creation of foreign key constraints fails." 
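+
+# Table-level options arrive as mysql_* keyword arguments on Table and
+# are rendered by MySQLDDLCompiler.post_create_table() below; a hedged
+# sketch (table name and option values hypothetical):
+#
+#   t = Table('data', metadata,
+#             Column('id', Integer, primary_key=True),
+#             mysql_engine='InnoDB', mysql_charset='utf8')
+#   # -> CREATE TABLE data (...) ENGINE=InnoDB CHARSET=utf8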
+ +class MySQLDDLCompiler(compiler.DDLCompiler): + def get_column_specification(self, column, **kw): + """Builds column DDL.""" + + colspec = [self.preparer.format_column(column), + self.dialect.type_compiler.process(column.type) + ] + + default = self.get_column_default_string(column) + if default is not None: + colspec.append('DEFAULT ' + default) + + if not column.nullable: + colspec.append('NOT NULL') + + if column.primary_key and column.autoincrement: + try: + first = [c for c in column.table.primary_key.columns + if (c.autoincrement and + isinstance(c.type, sqltypes.Integer) and + not c.foreign_keys)].pop(0) + if column is first: + colspec.append('AUTO_INCREMENT') + except IndexError: + pass + + return ' '.join(colspec) + + def post_create_table(self, table): + """Build table-level CREATE options like ENGINE and COLLATE.""" + + table_opts = [] + for k in table.kwargs: + if k.startswith('mysql_'): + opt = k[6:].upper() + joiner = '=' + if opt in ('TABLESPACE', 'DEFAULT CHARACTER SET', + 'CHARACTER SET', 'COLLATE'): + joiner = ' ' + + table_opts.append(joiner.join((opt, table.kwargs[k]))) + return ' '.join(table_opts) + + def visit_drop_index(self, drop): + index = drop.element + + return "\nDROP INDEX %s ON %s" % \ + (self.preparer.quote(self._validate_identifier(index.name, False), index.quote), + self.preparer.format_table(index.table)) + + def visit_drop_constraint(self, drop): + constraint = drop.element + if isinstance(constraint, sa_schema.ForeignKeyConstraint): + qual = "FOREIGN KEY " + const = self.preparer.format_constraint(constraint) + elif isinstance(constraint, sa_schema.PrimaryKeyConstraint): + qual = "PRIMARY KEY " + const = "" + elif isinstance(constraint, sa_schema.UniqueConstraint): + qual = "INDEX " + const = self.preparer.format_constraint(constraint) + else: + qual = "" + const = self.preparer.format_constraint(constraint) + return "ALTER TABLE %s DROP %s%s" % \ + (self.preparer.format_table(constraint.table), + qual, const) + +class MySQLTypeCompiler(compiler.GenericTypeCompiler): + def _extend_numeric(self, type_, spec): + "Extend a numeric-type declaration with MySQL specific extensions." + + if not self._mysql_type(type_): + return spec + + if type_.unsigned: + spec += ' UNSIGNED' + if type_.zerofill: + spec += ' ZEROFILL' + return spec + + def _extend_string(self, type_, defaults, spec): + """Extend a string-type declaration with standard SQL CHARACTER SET / + COLLATE annotations and MySQL specific extensions. + + """ + + def attr(name): + return getattr(type_, name, defaults.get(name)) + + if attr('charset'): + charset = 'CHARACTER SET %s' % attr('charset') + elif attr('ascii'): + charset = 'ASCII' + elif attr('unicode'): + charset = 'UNICODE' + else: + charset = None + + if attr('collation'): + collation = 'COLLATE %s' % type_.collation + elif attr('binary'): + collation = 'BINARY' + else: + collation = None + + if attr('national'): + # NATIONAL (aka NCHAR/NVARCHAR) trumps charsets. 
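+            # e.g. (hypothetical inputs) VARCHAR(30) with national=True
+            # renders as NATIONAL VARCHAR(30), while charset='utf8' and
+            # collation='utf8_bin' would instead render as
+            #   VARCHAR(30) CHARACTER SET utf8 COLLATE utf8_bin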
+ return ' '.join([c for c in ('NATIONAL', spec, collation) + if c is not None]) + return ' '.join([c for c in (spec, charset, collation) + if c is not None]) + + def _mysql_type(self, type_): + return isinstance(type_, (_StringType, _NumericType, _BinaryType)) + + def visit_NUMERIC(self, type_): + if type_.precision is None: + return self._extend_numeric(type_, "NUMERIC") + elif type_.scale is None: + return self._extend_numeric(type_, "NUMERIC(%(precision)s)" % {'precision': type_.precision}) + else: + return self._extend_numeric(type_, "NUMERIC(%(precision)s, %(scale)s)" % {'precision': type_.precision, 'scale' : type_.scale}) + + def visit_DECIMAL(self, type_): + if type_.precision is None: + return self._extend_numeric(type_, "DECIMAL") + elif type_.scale is None: + return self._extend_numeric(type_, "DECIMAL(%(precision)s)" % {'precision': type_.precision}) + else: + return self._extend_numeric(type_, "DECIMAL(%(precision)s, %(scale)s)" % {'precision': type_.precision, 'scale' : type_.scale}) + + def visit_DOUBLE(self, type_): + if type_.precision is not None and type_.scale is not None: + return self._extend_numeric(type_, "DOUBLE(%(precision)s, %(scale)s)" % + {'precision': type_.precision, + 'scale' : type_.scale}) + else: + return self._extend_numeric(type_, 'DOUBLE') + + def visit_REAL(self, type_): + if type_.precision is not None and type_.scale is not None: + return self._extend_numeric(type_, "REAL(%(precision)s, %(scale)s)" % + {'precision': type_.precision, + 'scale' : type_.scale}) + else: + return self._extend_numeric(type_, 'REAL') + + def visit_FLOAT(self, type_): + if self._mysql_type(type_) and type_.scale is not None and type_.precision is not None: + return self._extend_numeric(type_, "FLOAT(%s, %s)" % (type_.precision, type_.scale)) + elif type_.precision is not None: + return self._extend_numeric(type_, "FLOAT(%s)" % (type_.precision,)) + else: + return self._extend_numeric(type_, "FLOAT") + + def visit_INTEGER(self, type_): + if self._mysql_type(type_) and type_.display_width is not None: + return self._extend_numeric(type_, "INTEGER(%(display_width)s)" % {'display_width': type_.display_width}) + else: + return self._extend_numeric(type_, "INTEGER") + + def visit_BIGINT(self, type_): + if self._mysql_type(type_) and type_.display_width is not None: + return self._extend_numeric(type_, "BIGINT(%(display_width)s)" % {'display_width': type_.display_width}) + else: + return self._extend_numeric(type_, "BIGINT") + + def visit_MEDIUMINT(self, type_): + if self._mysql_type(type_) and type_.display_width is not None: + return self._extend_numeric(type_, "MEDIUMINT(%(display_width)s)" % {'display_width': type_.display_width}) + else: + return self._extend_numeric(type_, "MEDIUMINT") + + def visit_TINYINT(self, type_): + if self._mysql_type(type_) and type_.display_width is not None: + return self._extend_numeric(type_, "TINYINT(%s)" % type_.display_width) + else: + return self._extend_numeric(type_, "TINYINT") + + def visit_SMALLINT(self, type_): + if self._mysql_type(type_) and type_.display_width is not None: + return self._extend_numeric(type_, "SMALLINT(%(display_width)s)" % {'display_width': type_.display_width}) + else: + return self._extend_numeric(type_, "SMALLINT") + + def visit_BIT(self, type_): + if type_.length is not None: + return "BIT(%s)" % type_.length + else: + return "BIT" + + def visit_DATETIME(self, type_): + return "DATETIME" + + def visit_DATE(self, type_): + return "DATE" + + def visit_TIME(self, type_): + return "TIME" + + def 
visit_TIMESTAMP(self, type_): + return 'TIMESTAMP' + + def visit_YEAR(self, type_): + if type_.display_width is None: + return "YEAR" + else: + return "YEAR(%s)" % type_.display_width + + def visit_TEXT(self, type_): + if type_.length: + return self._extend_string(type_, {}, "TEXT(%d)" % type_.length) + else: + return self._extend_string(type_, {}, "TEXT") + + def visit_TINYTEXT(self, type_): + return self._extend_string(type_, {}, "TINYTEXT") + + def visit_MEDIUMTEXT(self, type_): + return self._extend_string(type_, {}, "MEDIUMTEXT") + + def visit_LONGTEXT(self, type_): + return self._extend_string(type_, {}, "LONGTEXT") + + def visit_VARCHAR(self, type_): + if type_.length: + return self._extend_string(type_, {}, "VARCHAR(%d)" % type_.length) + else: + return self._extend_string(type_, {}, "VARCHAR") + + def visit_CHAR(self, type_): + return self._extend_string(type_, {'national':True}, "CHAR(%(length)s)" % {'length' : type_.length}) + + def visit_NVARCHAR(self, type_): + # We'll actually generate the equiv. "NATIONAL VARCHAR" instead + # of "NVARCHAR". + return self._extend_string(type_, {'national':True}, "VARCHAR(%(length)s)" % {'length': type_.length}) + + def visit_NCHAR(self, type_): + # We'll actually generate the equiv. "NATIONAL CHAR" instead of "NCHAR". + return self._extend_string(type_, {'national':True}, "CHAR(%(length)s)" % {'length': type_.length}) + + def visit_VARBINARY(self, type_): + if type_.length: + return "VARBINARY(%d)" % type_.length + else: + return self.visit_BLOB(type_) + + def visit_binary(self, type_): + return self.visit_BLOB(type_) + + def visit_BINARY(self, type_): + if type_.length: + return "BINARY(%d)" % type_.length + else: + return self.visit_BLOB(type_) + + def visit_BLOB(self, type_): + if type_.length: + return "BLOB(%d)" % type_.length + else: + return "BLOB" + + def visit_TINYBLOB(self, type_): + return "TINYBLOB" + + def visit_MEDIUMBLOB(self, type_): + return "MEDIUMBLOB" + + def visit_LONGBLOB(self, type_): + return "LONGBLOB" + + def visit_ENUM(self, type_): + quoted_enums = [] + for e in type_.enums: + quoted_enums.append("'%s'" % e.replace("'", "''")) + return self._extend_string(type_, {}, "ENUM(%s)" % ",".join(quoted_enums)) + + def visit_SET(self, type_): + return self._extend_string(type_, {}, "SET(%s)" % ",".join(type_._ddl_values)) + + def visit_BOOLEAN(self, type): + return "BOOL" + + +class MySQLDialect(default.DefaultDialect): + """Details of the MySQL dialect. Not used directly in application code.""" + name = 'mysql' + supports_alter = True + # identifiers are 64, however aliases can be 255... + max_identifier_length = 255 + + supports_sane_rowcount = True + supports_sane_multi_rowcount = False + + default_paramstyle = 'format' + colspecs = colspecs + + statement_compiler = MySQLCompiler + ddl_compiler = MySQLDDLCompiler + type_compiler = MySQLTypeCompiler + ischema_names = ischema_names + + def __init__(self, use_ansiquotes=None, **kwargs): + default.DefaultDialect.__init__(self, **kwargs) + + def do_commit(self, connection): + """Execute a COMMIT.""" + + # COMMIT/ROLLBACK were introduced in 3.23.15. + # Yes, we have at least one user who has to talk to these old versions! + # + # Ignore commit/rollback if support isn't present, otherwise even basic + # operations via autocommit fail. 
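+        #
+        # (Error code 1064, checked below, is the server's generic
+        # parse-error response, which is what a pre-3.23.15 server
+        # answers COMMIT with.)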
+ try: + connection.commit() + except: + if self.server_version_info < (3, 23, 15): + args = sys.exc_info()[1].args + if args and args[0] == 1064: + return + raise + + def do_rollback(self, connection): + """Execute a ROLLBACK.""" + + try: + connection.rollback() + except: + if self.server_version_info < (3, 23, 15): + args = sys.exc_info()[1].args + if args and args[0] == 1064: + return + raise + + def do_begin_twophase(self, connection, xid): + connection.execute(sql.text("XA BEGIN :xid"), xid=xid) + + def do_prepare_twophase(self, connection, xid): + connection.execute(sql.text("XA END :xid"), xid=xid) + connection.execute(sql.text("XA PREPARE :xid"), xid=xid) + + def do_rollback_twophase(self, connection, xid, is_prepared=True, + recover=False): + if not is_prepared: + connection.execute(sql.text("XA END :xid"), xid=xid) + connection.execute(sql.text("XA ROLLBACK :xid"), xid=xid) + + def do_commit_twophase(self, connection, xid, is_prepared=True, + recover=False): + if not is_prepared: + self.do_prepare_twophase(connection, xid) + connection.execute(sql.text("XA COMMIT :xid"), xid=xid) + + def do_recover_twophase(self, connection): + resultset = connection.execute("XA RECOVER") + return [row['data'][0:row['gtrid_length']] for row in resultset] + + def is_disconnect(self, e): + if isinstance(e, self.dbapi.OperationalError): + return self._extract_error_code(e) in (2006, 2013, 2014, 2045, 2055) + elif isinstance(e, self.dbapi.InterfaceError): # if underlying connection is closed, this is the error you get + return "(0, '')" in str(e) + else: + return False + + def _compat_fetchall(self, rp, charset=None): + """Proxy result rows to smooth over MySQL-Python driver inconsistencies.""" + + return [_DecodingRowProxy(row, charset) for row in rp.fetchall()] + + def _compat_fetchone(self, rp, charset=None): + """Proxy a result row to smooth over MySQL-Python driver inconsistencies.""" + + return _DecodingRowProxy(rp.fetchone(), charset) + + def _extract_error_code(self, exception): + raise NotImplementedError() + + def get_default_schema_name(self, connection): + return connection.execute('SELECT DATABASE()').scalar() + + def table_names(self, connection, schema): + """Return a Unicode SHOW TABLES from a given schema.""" + + charset = self._connection_charset + rp = connection.execute("SHOW TABLES FROM %s" % + self.identifier_preparer.quote_identifier(schema)) + return [row[0] for row in self._compat_fetchall(rp, charset=charset)] + + def has_table(self, connection, table_name, schema=None): + # SHOW TABLE STATUS LIKE and SHOW TABLES LIKE do not function properly + # on macosx (and maybe win?) with multibyte table names. + # + # TODO: if this is not a problem on win, make the strategy swappable + # based on platform. DESCRIBE is slower. 
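+        #
+        # The probe issued below is simply (table name hypothetical):
+        #   DESCRIBE `mytable`
+        # and server error 1146 ("table doesn't exist") maps to False.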
+ + # [ticket:726] + # full_name = self.identifier_preparer.format_table(table, + # use_schema=True) + + + full_name = '.'.join(self.identifier_preparer._quote_free_identifiers( + schema, table_name)) + + st = "DESCRIBE %s" % full_name + rs = None + try: + try: + rs = connection.execute(st) + have = rs.rowcount > 0 + rs.close() + return have + except exc.SQLError, e: + if self._extract_error_code(e) == 1146: + return False + raise + finally: + if rs: + rs.close() + + def initialize(self, connection): + self.server_version_info = self._get_server_version_info(connection) + self._connection_charset = self._detect_charset(connection) + self._server_casing = self._detect_casing(connection) + self._server_collations = self._detect_collations(connection) + self._server_ansiquotes = self._detect_ansiquotes(connection) + + if self._server_ansiquotes: + self.preparer = MySQLANSIIdentifierPreparer + else: + self.preparer = MySQLIdentifierPreparer + self.identifier_preparer = self.preparer(self) + + @reflection.cache + def get_schema_names(self, connection, **kw): + rp = connection.execute("SHOW schemas") + return [r[0] for r in rp] + + @reflection.cache + def get_table_names(self, connection, schema=None, **kw): + if schema is None: + schema = self.get_default_schema_name(connection) + if self.server_version_info < (5, 0, 2): + return self.table_names(connection, schema) + charset = self._connection_charset + rp = connection.execute("SHOW FULL TABLES FROM %s" % + self.identifier_preparer.quote_identifier(schema)) + + return [row[0] for row in self._compat_fetchall(rp, charset=charset)\ + if row[1] == 'BASE TABLE'] + + @reflection.cache + def get_view_names(self, connection, schema=None, **kw): + charset = self._connection_charset + if self.server_version_info < (5, 0, 2): + raise NotImplementedError + if schema is None: + schema = self.get_default_schema_name(connection) + if self.server_version_info < (5, 0, 2): + return self.table_names(connection, schema) + charset = self._connection_charset + rp = connection.execute("SHOW FULL TABLES FROM %s" % + self.identifier_preparer.quote_identifier(schema)) + return [row[0] for row in self._compat_fetchall(rp, charset=charset)\ + if row[1] == 'VIEW'] + + @reflection.cache + def get_table_options(self, connection, table_name, schema=None, **kw): + + parsed_state = self._parsed_state_or_create(connection, table_name, schema, **kw) + return parsed_state.table_options + + @reflection.cache + def get_columns(self, connection, table_name, schema=None, **kw): + parsed_state = self._parsed_state_or_create(connection, table_name, schema, **kw) + return parsed_state.columns + + @reflection.cache + def get_primary_keys(self, connection, table_name, schema=None, **kw): + parsed_state = self._parsed_state_or_create(connection, table_name, schema, **kw) + for key in parsed_state.keys: + if key['type'] == 'PRIMARY': + # There can be only one. 
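+                # (columns arrive as (name, length) pairs from
+                #  _parse_keyexprs(), hence the s[0] below)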
+ ##raise Exception, str(key) + return [s[0] for s in key['columns']] + return [] + + @reflection.cache + def get_foreign_keys(self, connection, table_name, schema=None, **kw): + + parsed_state = self._parsed_state_or_create(connection, table_name, schema, **kw) + default_schema = None + + fkeys = [] + + for spec in parsed_state.constraints: + # only FOREIGN KEYs + ref_name = spec['table'][-1] + ref_schema = len(spec['table']) > 1 and spec['table'][-2] or schema + + if not ref_schema: + if default_schema is None: + default_schema = \ + connection.dialect.get_default_schema_name(connection) + if schema == default_schema: + ref_schema = schema + + loc_names = spec['local'] + ref_names = spec['foreign'] + + con_kw = {} + for opt in ('name', 'onupdate', 'ondelete'): + if spec.get(opt, False): + con_kw[opt] = spec[opt] + + fkey_d = { + 'name' : spec['name'], + 'constrained_columns' : loc_names, + 'referred_schema' : ref_schema, + 'referred_table' : ref_name, + 'referred_columns' : ref_names, + 'options' : con_kw + } + fkeys.append(fkey_d) + return fkeys + + @reflection.cache + def get_indexes(self, connection, table_name, schema=None, **kw): + + parsed_state = self._parsed_state_or_create(connection, table_name, schema, **kw) + + indexes = [] + for spec in parsed_state.keys: + unique = False + flavor = spec['type'] + if flavor == 'PRIMARY': + continue + if flavor == 'UNIQUE': + unique = True + elif flavor in (None, 'FULLTEXT', 'SPATIAL'): + pass + else: + self.logger.info( + "Converting unknown KEY type %s to a plain KEY" % flavor) + pass + index_d = {} + index_d['name'] = spec['name'] + index_d['column_names'] = [s[0] for s in spec['columns']] + index_d['unique'] = unique + index_d['type'] = flavor + indexes.append(index_d) + return indexes + + @reflection.cache + def get_view_definition(self, connection, view_name, schema=None, **kw): + + charset = self._connection_charset + full_name = '.'.join(self.identifier_preparer._quote_free_identifiers( + schema, view_name)) + sql = self._show_create_table(connection, None, charset, + full_name=full_name) + return sql + + def _parsed_state_or_create(self, connection, table_name, schema=None, **kw): + return self._setup_parser( + connection, + table_name, + schema, + info_cache=kw.get('info_cache', None) + ) + + @reflection.cache + def _setup_parser(self, connection, table_name, schema=None, **kw): + charset = self._connection_charset + try: + parser = self.parser + except AttributeError: + preparer = self.identifier_preparer + if (self.server_version_info < (4, 1) and + self._server_ansiquotes): + # ANSI_QUOTES doesn't affect SHOW CREATE TABLE on < 4.1 + preparer = MySQLIdentifierPreparer(self) + self.parser = parser = MySQLTableDefinitionParser(self, preparer) + full_name = '.'.join(self.identifier_preparer._quote_free_identifiers( + schema, table_name)) + sql = self._show_create_table(connection, None, charset, + full_name=full_name) + if sql.startswith('CREATE ALGORITHM'): + # Adapt views to something table-like. + columns = self._describe_table(connection, None, charset, + full_name=full_name) + sql = parser._describe_to_create(table_name, columns) + return parser.parse(sql, charset) + + def _adjust_casing(self, table, charset=None): + """Adjust Table name to the server case sensitivity, if needed.""" + + casing = self._server_casing + + # For winxx database hosts. TODO: is this really needed? 
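+        # (casing == 1 corresponds to lower_case_table_names=1: names
+        # are stored lowercase and compared case-insensitively, the
+        # server default on Windows.)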
+ if casing == 1 and table.name != table.name.lower(): + table.name = table.name.lower() + lc_alias = schema._get_table_key(table.name, table.schema) + table.metadata.tables[lc_alias] = table + + def _detect_charset(self, connection): + raise NotImplementedError() + + def _detect_casing(self, connection): + """Sniff out identifier case sensitivity. + + Cached per-connection. This value can not change without a server + restart. + + """ + # http://dev.mysql.com/doc/refman/5.0/en/name-case-sensitivity.html + + charset = self._connection_charset + row = self._compat_fetchone(connection.execute( + "SHOW VARIABLES LIKE 'lower_case_table_names'"), + charset=charset) + if not row: + cs = 0 + else: + # 4.0.15 returns OFF or ON according to [ticket:489] + # 3.23 doesn't, 4.0.27 doesn't.. + if row[1] == 'OFF': + cs = 0 + elif row[1] == 'ON': + cs = 1 + else: + cs = int(row[1]) + row.close() + return cs + + def _detect_collations(self, connection): + """Pull the active COLLATIONS list from the server. + + Cached per-connection. + """ + + collations = {} + if self.server_version_info < (4, 1, 0): + pass + else: + charset = self._connection_charset + rs = connection.execute('SHOW COLLATION') + for row in self._compat_fetchall(rs, charset): + collations[row[0]] = row[1] + return collations + + def _detect_ansiquotes(self, connection): + """Detect and adjust for the ANSI_QUOTES sql mode.""" + + row = self._compat_fetchone( + connection.execute("SHOW VARIABLES LIKE 'sql_mode'"), + charset=self._connection_charset) + + if not row: + mode = '' + else: + mode = row[1] or '' + # 4.0 + if mode.isdigit(): + mode_no = int(mode) + mode = (mode_no | 4 == mode_no) and 'ANSI_QUOTES' or '' + + return 'ANSI_QUOTES' in mode + + def _show_create_table(self, connection, table, charset=None, + full_name=None): + """Run SHOW CREATE TABLE for a ``Table``.""" + + if full_name is None: + full_name = self.identifier_preparer.format_table(table) + st = "SHOW CREATE TABLE %s" % full_name + + rp = None + try: + try: + rp = connection.execute(st) + except exc.SQLError, e: + if self._extract_error_code(e) == 1146: + raise exc.NoSuchTableError(full_name) + else: + raise + row = self._compat_fetchone(rp, charset=charset) + if not row: + raise exc.NoSuchTableError(full_name) + return row[1].strip() + finally: + if rp: + rp.close() + + return sql + + def _describe_table(self, connection, table, charset=None, + full_name=None): + """Run DESCRIBE for a ``Table`` and return processed rows.""" + + if full_name is None: + full_name = self.identifier_preparer.format_table(table) + st = "DESCRIBE %s" % full_name + + rp, rows = None, None + try: + try: + rp = connection.execute(st) + except exc.SQLError, e: + if self._extract_error_code(e) == 1146: + raise exc.NoSuchTableError(full_name) + else: + raise + rows = self._compat_fetchall(rp, charset=charset) + finally: + if rp: + rp.close() + return rows + +class ReflectedState(object): + """Stores raw information about a SHOW CREATE TABLE statement.""" + + def __init__(self): + self.columns = [] + self.table_options = {} + self.table_name = None + self.keys = [] + self.constraints = [] + +class MySQLTableDefinitionParser(object): + """Parses the results of a SHOW CREATE TABLE statement.""" + + def __init__(self, dialect, preparer): + self.dialect = dialect + self.preparer = preparer + self._prep_regexes() + + def parse(self, show_create, charset): + state = ReflectedState() + state.charset = charset + for line in re.split(r'\r?\n', show_create): + if line.startswith(' ' + 
self.preparer.initial_quote): + self._parse_column(line, state) + # a regular table options line + elif line.startswith(') '): + self._parse_table_options(line, state) + # an ANSI-mode table options line + elif line == ')': + pass + elif line.startswith('CREATE '): + self._parse_table_name(line, state) + # Not present in real reflection, but may be if loading from a file. + elif not line: + pass + else: + type_, spec = self._parse_constraints(line) + if type_ is None: + util.warn("Unknown schema content: %r" % line) + elif type_ == 'key': + state.keys.append(spec) + elif type_ == 'constraint': + state.constraints.append(spec) + else: + pass + + return state + + def _parse_constraints(self, line): + """Parse a KEY or CONSTRAINT line. + + line + A line of SHOW CREATE TABLE output + """ + + # KEY + m = self._re_key.match(line) + if m: + spec = m.groupdict() + # convert columns into name, length pairs + spec['columns'] = self._parse_keyexprs(spec['columns']) + return 'key', spec + + # CONSTRAINT + m = self._re_constraint.match(line) + if m: + spec = m.groupdict() + spec['table'] = \ + self.preparer.unformat_identifiers(spec['table']) + spec['local'] = [c[0] + for c in self._parse_keyexprs(spec['local'])] + spec['foreign'] = [c[0] + for c in self._parse_keyexprs(spec['foreign'])] + return 'constraint', spec + + # PARTITION and SUBPARTITION + m = self._re_partition.match(line) + if m: + # Punt! + return 'partition', line + + # No match. + return (None, line) + + def _parse_table_name(self, line, state): + """Extract the table name. + + line + The first line of SHOW CREATE TABLE + """ + + regex, cleanup = self._pr_name + m = regex.match(line) + if m: + state.table_name = cleanup(m.group('name')) + + def _parse_table_options(self, line, state): + """Build a dictionary of all reflected table-level options. + + line + The final line of SHOW CREATE TABLE output. + """ + + options = {} + + if not line or line == ')': + pass + + else: + r_eq_trim = self._re_options_util['='] + + for regex, cleanup in self._pr_options: + m = regex.search(line) + if not m: + continue + directive, value = m.group('directive'), m.group('val') + directive = r_eq_trim.sub('', directive).lower() + if cleanup: + value = cleanup(value) + options[directive] = value + + for nope in ('auto_increment', 'data_directory', 'index_directory'): + options.pop(nope, None) + + for opt, val in options.items(): + state.table_options['mysql_%s' % opt] = val + + def _parse_column(self, line, state): + """Extract column details. + + Falls back to a 'minimal support' variant if full parse fails. + + line + Any column-bearing line from SHOW CREATE TABLE + """ + + charset = state.charset + spec = None + m = self._re_column.match(line) + if m: + spec = m.groupdict() + spec['full'] = True + else: + m = self._re_column_loose.match(line) + if m: + spec = m.groupdict() + spec['full'] = False + if not spec: + util.warn("Unknown column definition %r" % line) + return + if not spec['full']: + util.warn("Incomplete reflection of column definition %r" % line) + + name, type_, args, notnull = \ + spec['name'], spec['coltype'], spec['arg'], spec['notnull'] + + # Convention says that TINYINT(1) columns == BOOLEAN + if type_ == 'tinyint' and args == '1': + type_ = 'boolean' + args = None + + try: + col_type = self.dialect.ischema_names[type_] + except KeyError: + util.warn("Did not recognize type '%s' of column '%s'" % + (type_, name)) + col_type = sqltypes.NullType + + # Column type positional arguments eg. 
varchar(32) + if args is None or args == '': + type_args = [] + elif args[0] == "'" and args[-1] == "'": + type_args = self._re_csv_str.findall(args) + else: + type_args = [int(v) for v in self._re_csv_int.findall(args)] + + # Column type keyword options + type_kw = {} + for kw in ('unsigned', 'zerofill'): + if spec.get(kw, False): + type_kw[kw] = True + for kw in ('charset', 'collate'): + if spec.get(kw, False): + type_kw[kw] = spec[kw] + + if type_ == 'enum': + type_kw['quoting'] = 'quoted' + + type_instance = col_type(*type_args, **type_kw) + + col_args, col_kw = [], {} + + # NOT NULL + col_kw['nullable'] = True + if spec.get('notnull', False): + col_kw['nullable'] = False + + # AUTO_INCREMENT + if spec.get('autoincr', False): + col_kw['autoincrement'] = True + elif issubclass(col_type, sqltypes.Integer): + col_kw['autoincrement'] = False + + # DEFAULT + default = spec.get('default', None) + + if default == 'NULL': + # eliminates the need to deal with this later. + default = None + + col_d = dict(name=name, type=type_instance, default=default) + col_d.update(col_kw) + state.columns.append(col_d) + + def _describe_to_create(self, table_name, columns): + """Re-format DESCRIBE output as a SHOW CREATE TABLE string. + + DESCRIBE is a much simpler reflection and is sufficient for + reflecting views for runtime use. This method formats DDL + for columns only- keys are omitted. + + `columns` is a sequence of DESCRIBE or SHOW COLUMNS 6-tuples. + SHOW FULL COLUMNS FROM rows must be rearranged for use with + this function. + """ + + buffer = [] + for row in columns: + (name, col_type, nullable, default, extra) = \ + [row[i] for i in (0, 1, 2, 4, 5)] + + line = [' '] + line.append(self.preparer.quote_identifier(name)) + line.append(col_type) + if not nullable: + line.append('NOT NULL') + if default: + if 'auto_increment' in default: + pass + elif (col_type.startswith('timestamp') and + default.startswith('C')): + line.append('DEFAULT') + line.append(default) + elif default == 'NULL': + line.append('DEFAULT') + line.append(default) + else: + line.append('DEFAULT') + line.append("'%s'" % default.replace("'", "''")) + if extra: + line.append(extra) + + buffer.append(' '.join(line)) + + return ''.join([('CREATE TABLE %s (\n' % + self.preparer.quote_identifier(table_name)), + ',\n'.join(buffer), + '\n) ']) + + def _parse_keyexprs(self, identifiers): + """Unpack '"col"(2),"col" ASC'-ish strings into components.""" + + return self._re_keyexprs.findall(identifiers) + + def _prep_regexes(self): + """Pre-compile regular expressions.""" + + self._re_columns = [] + self._pr_options = [] + self._re_options_util = {} + + _final = self.preparer.final_quote + + quotes = dict(zip(('iq', 'fq', 'esc_fq'), + [re.escape(s) for s in + (self.preparer.initial_quote, + _final, + self.preparer._escape_identifier(_final))])) + + self._pr_name = _pr_compile( + r'^CREATE (?:\w+ +)?TABLE +' + r'%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +\($' % quotes, + self.preparer._unescape_identifier) + + # `col`,`col2`(32),`col3`(15) DESC + # + # Note: ASC and DESC aren't reflected, so we'll punt... + self._re_keyexprs = _re_compile( + r'(?:' + r'(?:%(iq)s((?:%(esc_fq)s|[^%(fq)s])+)%(fq)s)' + r'(?:\((\d+)\))?(?=\,|$))+' % quotes) + + # 'foo' or 'foo','bar' or 'fo,o','ba''a''r' + self._re_csv_str = _re_compile(r'\x27(?:\x27\x27|[^\x27])*\x27') + + # 123 or 123,456 + self._re_csv_int = _re_compile(r'\d+') + + + # `colname` <type> [type opts] + # (NOT NULL | NULL) + # DEFAULT ('value' | CURRENT_TIMESTAMP...) 
+        #   COMMENT 'comment'
+        #   COLUMN_FORMAT (FIXED|DYNAMIC|DEFAULT)
+        #   STORAGE (DISK|MEMORY)
+        self._re_column = _re_compile(
+            r'  '
+            r'%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +'
+            r'(?P<coltype>\w+)'
+            r'(?:\((?P<arg>(?:\d+|\d+,\d+|'
+            r'(?:\x27(?:\x27\x27|[^\x27])*\x27,?)+))\))?'
+            r'(?: +(?P<unsigned>UNSIGNED))?'
+            r'(?: +(?P<zerofill>ZEROFILL))?'
+            r'(?: +CHARACTER SET +(?P<charset>\w+))?'
+            r'(?: +COLLATE +(?P<collate>\w+))?'
+            r'(?: +(?P<notnull>NOT NULL))?'
+            r'(?: +DEFAULT +(?P<default>'
+            r'(?:NULL|\x27(?:\x27\x27|[^\x27])*\x27|\w+)'
+            r'(?:ON UPDATE \w+)?'
+            r'))?'
+            r'(?: +(?P<autoincr>AUTO_INCREMENT))?'
+            r'(?: +COMMENT +(?P<comment>(?:\x27\x27|[^\x27])+))?'
+            r'(?: +COLUMN_FORMAT +(?P<colfmt>\w+))?'
+            r'(?: +STORAGE +(?P<storage>\w+))?'
+            r'(?: +(?P<extra>.*))?'
+            r',?$'
+            % quotes
+            )
+
+        # Fallback, try to parse as little as possible
+        self._re_column_loose = _re_compile(
+            r'  '
+            r'%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +'
+            r'(?P<coltype>\w+)'
+            r'(?:\((?P<arg>(?:\d+|\d+,\d+|\x27(?:\x27\x27|[^\x27])+\x27))\))?'
+            r'.*?(?P<notnull>NOT NULL)?'
+            % quotes
+            )
+
+        # (PRIMARY|UNIQUE|FULLTEXT|SPATIAL) INDEX `name` (USING (BTREE|HASH))?
+        # (`col` (ASC|DESC)?, `col` (ASC|DESC)?)
+        # KEY_BLOCK_SIZE size | WITH PARSER name
+        self._re_key = _re_compile(
+            r'  '
+            r'(?:(?P<type>\S+) )?KEY'
+            r'(?: +%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s)?'
+            r'(?: +USING +(?P<using_pre>\S+))?'
+            r' +\((?P<columns>.+?)\)'
+            r'(?: +USING +(?P<using_post>\S+))?'
+            r'(?: +KEY_BLOCK_SIZE +(?P<keyblock>\S+))?'
+            r'(?: +WITH PARSER +(?P<parser>\S+))?'
+            r',?$'
+            % quotes
+            )
+
+        # CONSTRAINT `name` FOREIGN KEY (`local_col`)
+        # REFERENCES `remote` (`remote_col`)
+        # MATCH FULL | MATCH PARTIAL | MATCH SIMPLE
+        # ON DELETE CASCADE ON UPDATE RESTRICT
+        #
+        # unique constraints come back as KEYs
+        kw = quotes.copy()
+        kw['on'] = 'RESTRICT|CASCADE|SET NULL|NO ACTION'
+        self._re_constraint = _re_compile(
+            r'  '
+            r'CONSTRAINT +'
+            r'%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +'
+            r'FOREIGN KEY +'
+            r'\((?P<local>[^\)]+?)\) REFERENCES +'
+            r'(?P<table>%(iq)s[^%(fq)s]+%(fq)s(?:\.%(iq)s[^%(fq)s]+%(fq)s)?) +'
+            r'\((?P<foreign>[^\)]+?)\)'
+            r'(?: +(?P<match>MATCH \w+))?'
+            r'(?: +ON DELETE (?P<ondelete>%(on)s))?'
+            r'(?: +ON UPDATE (?P<onupdate>%(on)s))?'
+            % kw
+            )
+
+        # PARTITION
+        #
+        # punt!
+        self._re_partition = _re_compile(
+            r'  '
+            r'(?:SUB)?PARTITION')
+
+        # Table-level options (COLLATE, ENGINE, etc.)
+        for option in ('ENGINE', 'TYPE', 'AUTO_INCREMENT',
+                       'AVG_ROW_LENGTH', 'CHARACTER SET',
+                       'DEFAULT CHARSET', 'CHECKSUM',
+                       'COLLATE', 'DELAY_KEY_WRITE', 'INSERT_METHOD',
+                       'MAX_ROWS', 'MIN_ROWS', 'PACK_KEYS', 'ROW_FORMAT',
+                       'KEY_BLOCK_SIZE'):
+            self._add_option_word(option)
+
+        for option in ('COMMENT', 'DATA_DIRECTORY', 'INDEX_DIRECTORY',
+                       'PASSWORD', 'CONNECTION'):
+            self._add_option_string(option)
+
+        self._add_option_regex('UNION', r'\([^\)]+\)')
+        self._add_option_regex('TABLESPACE', r'.*? 
STORAGE DISK') + self._add_option_regex('RAID_TYPE', + r'\w+\s+RAID_CHUNKS\s*\=\s*\w+RAID_CHUNKSIZE\s*=\s*\w+') + self._re_options_util['='] = _re_compile(r'\s*=\s*$') + + def _add_option_string(self, directive): + regex = (r'(?P<directive>%s\s*(?:=\s*)?)' + r'(?:\x27.(?P<val>.*?)\x27(?!\x27)\x27)' % + re.escape(directive)) + self._pr_options.append( + _pr_compile(regex, lambda v: v.replace("''", "'"))) + + def _add_option_word(self, directive): + regex = (r'(?P<directive>%s\s*(?:=\s*)?)' + r'(?P<val>\w+)' % re.escape(directive)) + self._pr_options.append(_pr_compile(regex)) + + def _add_option_regex(self, directive, regex): + regex = (r'(?P<directive>%s\s*(?:=\s*)?)' + r'(?P<val>%s)' % (re.escape(directive), regex)) + self._pr_options.append(_pr_compile(regex)) + +log.class_logger(MySQLTableDefinitionParser) +log.class_logger(MySQLDialect) + + +class _DecodingRowProxy(object): + """Return unicode-decoded values based on type inspection. + + Smooth over data type issues (esp. with alpha driver versions) and + normalize strings as Unicode regardless of user-configured driver + encoding settings. + + """ + + # Some MySQL-python versions can return some columns as + # sets.Set(['value']) (seriously) but thankfully that doesn't + # seem to come up in DDL queries. + + def __init__(self, rowproxy, charset): + self.rowproxy = rowproxy + self.charset = charset + def __getitem__(self, index): + item = self.rowproxy[index] + if isinstance(item, _array): + item = item.tostring() + if self.charset and isinstance(item, str): + return item.decode(self.charset) + else: + return item + def __getattr__(self, attr): + item = getattr(self.rowproxy, attr) + if isinstance(item, _array): + item = item.tostring() + if self.charset and isinstance(item, str): + return item.decode(self.charset) + else: + return item + + +class _MySQLIdentifierPreparer(compiler.IdentifierPreparer): + """MySQL-specific schema identifier configuration.""" + + reserved_words = RESERVED_WORDS + + def __init__(self, dialect, **kw): + super(_MySQLIdentifierPreparer, self).__init__(dialect, **kw) + + def _quote_free_identifiers(self, *ids): + """Unilaterally identifier-quote any number of strings.""" + + return tuple([self.quote_identifier(i) for i in ids if i is not None]) + + +class MySQLIdentifierPreparer(_MySQLIdentifierPreparer): + """Traditional MySQL-specific schema identifier configuration.""" + + def __init__(self, dialect): + super(MySQLIdentifierPreparer, self).__init__(dialect, initial_quote="`") + + def _escape_identifier(self, value): + return value.replace('`', '``') + + def _unescape_identifier(self, value): + return value.replace('``', '`') + + +class MySQLANSIIdentifierPreparer(_MySQLIdentifierPreparer): + """ANSI_QUOTES MySQL schema identifier configuration.""" + + pass + +def _pr_compile(regex, cleanup=None): + """Prepare a 2-tuple of compiled regex and callable.""" + + return (_re_compile(regex), cleanup) + +def _re_compile(regex): + """Compile a string to regex, I and UNICODE.""" + + return re.compile(regex, re.I | re.UNICODE) + diff --git a/lib/sqlalchemy/dialects/mysql/mysqldb.py b/lib/sqlalchemy/dialects/mysql/mysqldb.py new file mode 100644 index 000000000..6ecfc4b84 --- /dev/null +++ b/lib/sqlalchemy/dialects/mysql/mysqldb.py @@ -0,0 +1,194 @@ +"""Support for the MySQL database via the MySQL-python adapter. + +Character Sets +-------------- + +Many MySQL server installations default to a ``latin1`` encoding for client +connections. 
All data sent through the connection will be converted into +``latin1``, even if you have ``utf8`` or another character set on your tables +and columns. With versions 4.1 and higher, you can change the connection +character set either through server configuration or by including the +``charset`` parameter in the URL used for ``create_engine``. The ``charset`` +option is passed through to MySQL-Python and has the side-effect of also +enabling ``use_unicode`` in the driver by default. For regular encoded +strings, also pass ``use_unicode=0`` in the connection arguments:: + + # set client encoding to utf8; all strings come back as unicode + create_engine('mysql:///mydb?charset=utf8') + + # set client encoding to utf8; all strings come back as utf8 str + create_engine('mysql:///mydb?charset=utf8&use_unicode=0') +""" + +import decimal +import re + +from sqlalchemy.dialects.mysql.base import (DECIMAL, MySQLDialect, MySQLExecutionContext, + MySQLCompiler, NUMERIC, _NumericType) +from sqlalchemy.engine import base as engine_base, default +from sqlalchemy.sql import operators as sql_operators +from sqlalchemy import exc, log, schema, sql, types as sqltypes, util + +class MySQL_mysqldbExecutionContext(MySQLExecutionContext): + + @property + def rowcount(self): + if hasattr(self, '_rowcount'): + return self._rowcount + else: + return self.cursor.rowcount + + +class MySQL_mysqldbCompiler(MySQLCompiler): + def visit_mod(self, binary, **kw): + return self.process(binary.left) + " %% " + self.process(binary.right) + + def post_process_text(self, text): + return text.replace('%', '%%') + + +class _DecimalType(_NumericType): + def result_processor(self, dialect): + if self.asdecimal: + return + def process(value): + if isinstance(value, decimal.Decimal): + return float(value) + else: + return value + return process + + +class _MySQLdbNumeric(_DecimalType, NUMERIC): + pass + + +class _MySQLdbDecimal(_DecimalType, DECIMAL): + pass + + +class MySQL_mysqldb(MySQLDialect): + driver = 'mysqldb' + supports_unicode_statements = False + supports_sane_rowcount = True + supports_sane_multi_rowcount = True + + default_paramstyle = 'format' + execution_ctx_cls = MySQL_mysqldbExecutionContext + statement_compiler = MySQL_mysqldbCompiler + + colspecs = util.update_copy( + MySQLDialect.colspecs, + { + sqltypes.Numeric: _MySQLdbNumeric, + DECIMAL: _MySQLdbDecimal + } + ) + + @classmethod + def dbapi(cls): + return __import__('MySQLdb') + + def do_executemany(self, cursor, statement, parameters, context=None): + rowcount = cursor.executemany(statement, parameters) + if context is not None: + context._rowcount = rowcount + + def create_connect_args(self, url): + opts = url.translate_connect_args(database='db', username='user', + password='passwd') + opts.update(url.query) + + util.coerce_kw_type(opts, 'compress', bool) + util.coerce_kw_type(opts, 'connect_timeout', int) + util.coerce_kw_type(opts, 'client_flag', int) + util.coerce_kw_type(opts, 'local_infile', int) + # Note: using either of the below will cause all strings to be returned + # as Unicode, both in raw SQL operations and with column types like + # String and MSString. + util.coerce_kw_type(opts, 'use_unicode', bool) + util.coerce_kw_type(opts, 'charset', str) + + # Rich values 'cursorclass' and 'conv' are not supported via + # query string. 
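+        #
+        # ssl_* query arguments, by contrast, are folded into MySQLdb's
+        # 'ssl' dict below, e.g. (hypothetical paths):
+        #   mysql://user:pass@host/db?ssl_ca=/path/ca.pem&ssl_key=/path/key.pem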
+ + ssl = {} + for key in ['ssl_ca', 'ssl_key', 'ssl_cert', 'ssl_capath', 'ssl_cipher']: + if key in opts: + ssl[key[4:]] = opts[key] + util.coerce_kw_type(ssl, key[4:], str) + del opts[key] + if ssl: + opts['ssl'] = ssl + + # FOUND_ROWS must be set in CLIENT_FLAGS to enable + # supports_sane_rowcount. + client_flag = opts.get('client_flag', 0) + if self.dbapi is not None: + try: + from MySQLdb.constants import CLIENT as CLIENT_FLAGS + client_flag |= CLIENT_FLAGS.FOUND_ROWS + except: + pass + opts['client_flag'] = client_flag + return [[], opts] + + def do_ping(self, connection): + connection.ping() + + def _get_server_version_info(self, connection): + dbapi_con = connection.connection + version = [] + r = re.compile('[.\-]') + for n in r.split(dbapi_con.get_server_info()): + try: + version.append(int(n)) + except ValueError: + version.append(n) + return tuple(version) + + def _extract_error_code(self, exception): + try: + return exception.orig.args[0] + except AttributeError: + return None + + def _detect_charset(self, connection): + """Sniff out the character set in use for connection results.""" + + # Note: MySQL-python 1.2.1c7 seems to ignore changes made + # on a connection via set_character_set() + if self.server_version_info < (4, 1, 0): + try: + return connection.connection.character_set_name() + except AttributeError: + # < 1.2.1 final MySQL-python drivers have no charset support. + # a query is needed. + pass + + # Prefer 'character_set_results' for the current connection over the + # value in the driver. SET NAMES or individual variable SETs will + # change the charset without updating the driver's view of the world. + # + # If it's decided that issuing that sort of SQL leaves you SOL, then + # this can prefer the driver value. + rs = connection.execute("SHOW VARIABLES LIKE 'character_set%%'") + opts = dict([(row[0], row[1]) for row in self._compat_fetchall(rs)]) + + if 'character_set_results' in opts: + return opts['character_set_results'] + try: + return connection.connection.character_set_name() + except AttributeError: + # Still no charset on < 1.2.1 final... + if 'character_set' in opts: + return opts['character_set'] + else: + util.warn( + "Could not detect the connection character set with this " + "combination of MySQL server and MySQL-python. " + "MySQL-python >= 1.2.2 is recommended. 
Assuming latin1.") + return 'latin1' + + +dialect = MySQL_mysqldb diff --git a/lib/sqlalchemy/dialects/mysql/pyodbc.py b/lib/sqlalchemy/dialects/mysql/pyodbc.py new file mode 100644 index 000000000..1ea7ec864 --- /dev/null +++ b/lib/sqlalchemy/dialects/mysql/pyodbc.py @@ -0,0 +1,54 @@ +from sqlalchemy.dialects.mysql.base import MySQLDialect, MySQLExecutionContext +from sqlalchemy.connectors.pyodbc import PyODBCConnector +from sqlalchemy.engine import base as engine_base +from sqlalchemy import util +import re + +class MySQL_pyodbcExecutionContext(MySQLExecutionContext): + + def get_lastrowid(self): + cursor = self.create_cursor() + cursor.execute("SELECT LAST_INSERT_ID()") + lastrowid = cursor.fetchone()[0] + cursor.close() + return lastrowid + +class MySQL_pyodbc(PyODBCConnector, MySQLDialect): + supports_unicode_statements = False + execution_ctx_cls = MySQL_pyodbcExecutionContext + + pyodbc_driver_name = "MySQL" + + def __init__(self, **kw): + # deal with http://code.google.com/p/pyodbc/issues/detail?id=25 + kw.setdefault('convert_unicode', True) + MySQLDialect.__init__(self, **kw) + PyODBCConnector.__init__(self, **kw) + + def _detect_charset(self, connection): + """Sniff out the character set in use for connection results.""" + + # Prefer 'character_set_results' for the current connection over the + # value in the driver. SET NAMES or individual variable SETs will + # change the charset without updating the driver's view of the world. + # + # If it's decided that issuing that sort of SQL leaves you SOL, then + # this can prefer the driver value. + rs = connection.execute("SHOW VARIABLES LIKE 'character_set%%'") + opts = dict([(row[0], row[1]) for row in self._compat_fetchall(rs)]) + for key in ('character_set_connection', 'character_set'): + if opts.get(key, None): + return opts[key] + + util.warn("Could not detect the connection character set. Assuming latin1.") + return 'latin1' + + def _extract_error_code(self, exception): + m = re.compile(r"\((\d+)\)").search(str(exception.orig.args)) + c = m.group(1) + if c: + return int(c) + else: + return None + +dialect = MySQL_pyodbc diff --git a/lib/sqlalchemy/dialects/mysql/zxjdbc.py b/lib/sqlalchemy/dialects/mysql/zxjdbc.py new file mode 100644 index 000000000..6cdc6f438 --- /dev/null +++ b/lib/sqlalchemy/dialects/mysql/zxjdbc.py @@ -0,0 +1,95 @@ +"""Support for the MySQL database via Jython's zxjdbc JDBC connector. + +JDBC Driver +----------- + +The official MySQL JDBC driver is at +http://dev.mysql.com/downloads/connector/j/. 
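+
+Connecting
+----------
+
+Connect strings take the usual dialect+driver form; a sketch, with
+hypothetical credentials, host and database::
+
+  create_engine('mysql+zxjdbc://user:password@localhost:3306/test')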
+ +""" +import re + +from sqlalchemy import types as sqltypes, util +from sqlalchemy.connectors.zxJDBC import ZxJDBCConnector +from sqlalchemy.dialects.mysql.base import BIT, MySQLDialect, MySQLExecutionContext + +class _JDBCBit(BIT): + def result_processor(self, dialect): + """Converts boolean or byte arrays from MySQL Connector/J to longs.""" + def process(value): + if value is None: + return value + if isinstance(value, bool): + return int(value) + v = 0L + for i in value: + v = v << 8 | (i & 0xff) + value = v + return value + return process + + +class MySQL_jdbcExecutionContext(MySQLExecutionContext): + def get_lastrowid(self): + cursor = self.create_cursor() + cursor.execute("SELECT LAST_INSERT_ID()") + lastrowid = cursor.fetchone()[0] + cursor.close() + return lastrowid + + +class MySQL_jdbc(ZxJDBCConnector, MySQLDialect): + execution_ctx_cls = MySQL_jdbcExecutionContext + + jdbc_db_name = 'mysql' + jdbc_driver_name = 'com.mysql.jdbc.Driver' + + colspecs = util.update_copy( + MySQLDialect.colspecs, + { + sqltypes.Time: sqltypes.Time, + BIT: _JDBCBit + } + ) + + def _detect_charset(self, connection): + """Sniff out the character set in use for connection results.""" + # Prefer 'character_set_results' for the current connection over the + # value in the driver. SET NAMES or individual variable SETs will + # change the charset without updating the driver's view of the world. + # + # If it's decided that issuing that sort of SQL leaves you SOL, then + # this can prefer the driver value. + rs = connection.execute("SHOW VARIABLES LIKE 'character_set%%'") + opts = dict((row[0], row[1]) for row in self._compat_fetchall(rs)) + for key in ('character_set_connection', 'character_set'): + if opts.get(key, None): + return opts[key] + + util.warn("Could not detect the connection character set. Assuming latin1.") + return 'latin1' + + def _driver_kwargs(self): + """return kw arg dict to be sent to connect().""" + return dict(CHARSET=self.encoding, yearIsDateType='false') + + def _extract_error_code(self, exception): + # e.g.: DBAPIError: (Error) Table 'test.u2' doesn't exist + # [SQLCode: 1146], [SQLState: 42S02] 'DESCRIBE `u2`' () + m = re.compile(r"\[SQLCode\: (\d+)\]").search(str(exception.orig.args)) + c = m.group(1) + if c: + return int(c) + + def _get_server_version_info(self,connection): + dbapi_con = connection.connection + version = [] + r = re.compile('[.\-]') + for n in r.split(dbapi_con.dbversion): + try: + version.append(int(n)) + except ValueError: + version.append(n) + return tuple(version) + +dialect = MySQL_jdbc diff --git a/lib/sqlalchemy/dialects/oracle/__init__.py b/lib/sqlalchemy/dialects/oracle/__init__.py new file mode 100644 index 000000000..3b4379ab7 --- /dev/null +++ b/lib/sqlalchemy/dialects/oracle/__init__.py @@ -0,0 +1,3 @@ +from sqlalchemy.dialects.oracle import base, cx_oracle, zxjdbc + +base.dialect = cx_oracle.dialect diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py new file mode 100644 index 000000000..419ccedb1 --- /dev/null +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -0,0 +1,904 @@ +# oracle/base.py +# Copyright (C) 2005, 2006, 2007, 2008, 2009 Michael Bayer mike_mp@zzzcomputing.com +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php +"""Support for the Oracle database. + +Oracle version 8 through current (11g at the time of this writing) are supported. 
+ +For information on connecting via specific drivers, see the documentation +for that driver. + +Connect Arguments +----------------- + +The dialect supports several :func:`~sqlalchemy.create_engine()` arguments which +affect the behavior of the dialect regardless of driver in use. + +* *use_ansi* - Use ANSI JOIN constructs (see the section on Oracle 8). Defaults + to ``True``. If ``False``, Oracle-8 compatible constructs are used for joins. + +* *optimize_limits* - defaults to ``False``. see the section on LIMIT/OFFSET. + +Auto Increment Behavior +----------------------- + +SQLAlchemy Table objects which include integer primary keys are usually assumed to have +"autoincrementing" behavior, meaning they can generate their own primary key values upon +INSERT. Since Oracle has no "autoincrement" feature, SQLAlchemy relies upon sequences +to produce these values. With the Oracle dialect, *a sequence must always be explicitly +specified to enable autoincrement*. This is divergent with the majority of documentation +examples which assume the usage of an autoincrement-capable database. To specify sequences, +use the sqlalchemy.schema.Sequence object which is passed to a Column construct:: + + t = Table('mytable', metadata, + Column('id', Integer, Sequence('id_seq'), primary_key=True), + Column(...), ... + ) + +This step is also required when using table reflection, i.e. autoload=True:: + + t = Table('mytable', metadata, + Column('id', Integer, Sequence('id_seq'), primary_key=True), + autoload=True + ) + +Identifier Casing +----------------- + +In Oracle, the data dictionary represents all case insensitive identifier names +using UPPERCASE text. SQLAlchemy on the other hand considers an all-lower case identifier +name to be case insensitive. The Oracle dialect converts all case insensitive identifiers +to and from those two formats during schema level communication, such as reflection of +tables and indexes. Using an UPPERCASE name on the SQLAlchemy side indicates a +case sensitive identifier, and SQLAlchemy will quote the name - this will cause mismatches +against data dictionary data received from Oracle, so unless identifier names have been +truly created as case sensitive (i.e. using quoted names), all lowercase names should be +used on the SQLAlchemy side. + +Unicode +------- + +SQLAlchemy 0.6 uses the "native unicode" mode provided as of cx_oracle 5. cx_oracle 5.0.2 +or greater is recommended for support of NCLOB. If not using cx_oracle 5, the NLS_LANG +environment variable needs to be set in order for the oracle client library to use +proper encoding, such as "AMERICAN_AMERICA.UTF8". + +Also note that Oracle supports unicode data through the NVARCHAR and NCLOB data types. +When using the SQLAlchemy Unicode and UnicodeText types, these DDL types will be used +within CREATE TABLE statements. Usage of VARCHAR2 and CLOB with unicode text still +requires NLS_LANG to be set. + +LIMIT/OFFSET Support +-------------------- + +Oracle has no support for the LIMIT or OFFSET keywords. Whereas previous versions of SQLAlchemy +used the "ROW NUMBER OVER..." construct to simulate LIMIT/OFFSET, SQLAlchemy 0.5 now uses +a wrapped subquery approach in conjunction with ROWNUM. The exact methodology is taken from +http://www.oracle.com/technology/oramag/oracle/06-sep/o56asktom.html . 
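+
+Following the approach in that article, the emitted statement looks
+roughly like this sketch (inner query and bind names hypothetical)::
+
+    SELECT * FROM
+      (SELECT a.*, ROWNUM rnum FROM (SELECT ... FROM mytable) a
+       WHERE ROWNUM <= :limit + :offset)
+    WHERE rnum > :offset
+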
Note that the +"FIRST ROWS()" optimization keyword mentioned is not used by default, as the user community felt +this was stepping into the bounds of optimization that is better left on the DBA side, but this +prefix can be added by enabling the optimize_limits=True flag on create_engine(). + +ON UPDATE CASCADE +----------------- + +Oracle doesn't have native ON UPDATE CASCADE functionality. A trigger based solution +is available at http://asktom.oracle.com/tkyte/update_cascade/index.html . + +When using the SQLAlchemy ORM, the ORM has limited ability to manually issue +cascading updates - specify ForeignKey objects using the +"deferrable=True, initially='deferred'" keyword arguments, +and specify "passive_updates=False" on each relation(). + +Oracle 8 Compatibility +---------------------- + +When using Oracle 8, a "use_ansi=False" flag is available which converts all +JOIN phrases into the WHERE clause, and in the case of LEFT OUTER JOIN +makes use of Oracle's (+) operator. + +Synonym/DBLINK Reflection +------------------------- + +When using reflection with Table objects, the dialect can optionally search for tables +indicated by synonyms that reference DBLINK-ed tables by passing the flag +oracle_resolve_synonyms=True as a keyword argument to the Table construct. If DBLINK +is not in use this flag should be left off. + +""" + +import random, re + +from sqlalchemy import schema as sa_schema +from sqlalchemy import util, sql, log +from sqlalchemy.engine import default, base, reflection +from sqlalchemy.sql import compiler, visitors, expression +from sqlalchemy.sql import operators as sql_operators, functions as sql_functions +from sqlalchemy import types as sqltypes +from sqlalchemy.types import VARCHAR, NVARCHAR, CHAR, DATE, DATETIME, \ + BLOB, CLOB, TIMESTAMP, FLOAT + +RESERVED_WORDS = set('''SHARE RAW DROP BETWEEN FROM DESC OPTION PRIOR LONG THEN DEFAULT ALTER IS INTO MINUS INTEGER NUMBER GRANT IDENTIFIED ALL TO ORDER ON FLOAT DATE HAVING CLUSTER NOWAIT RESOURCE ANY TABLE INDEX FOR UPDATE WHERE CHECK SMALLINT WITH DELETE BY ASC REVOKE LIKE SIZE RENAME NOCOMPRESS NULL GROUP VALUES AS IN VIEW EXCLUSIVE COMPRESS SYNONYM SELECT INSERT EXISTS NOT TRIGGER ELSE CREATE INTERSECT PCTFREE DISTINCT USER CONNECT SET MODE OF UNIQUE VARCHAR2 VARCHAR LOCK OR CHAR DECIMAL UNION PUBLIC AND START UID COMMENT'''.split()) + +class RAW(sqltypes.Binary): + pass +OracleRaw = RAW + +class NCLOB(sqltypes.Text): + __visit_name__ = 'NCLOB' + +VARCHAR2 = VARCHAR +NVARCHAR2 = NVARCHAR + +class NUMBER(sqltypes.Numeric): + __visit_name__ = 'NUMBER' + +class BFILE(sqltypes.Binary): + __visit_name__ = 'BFILE' + +class DOUBLE_PRECISION(sqltypes.Numeric): + __visit_name__ = 'DOUBLE_PRECISION' + +class LONG(sqltypes.Text): + __visit_name__ = 'LONG' + +class _OracleBoolean(sqltypes.Boolean): + def get_dbapi_type(self, dbapi): + return dbapi.NUMBER + + def result_processor(self, dialect): + def process(value): + if value is None: + return None + return value and True or False + return process + + def bind_processor(self, dialect): + def process(value): + if value is True: + return 1 + elif value is False: + return 0 + elif value is None: + return None + else: + return value and True or False + return process + +colspecs = { + sqltypes.Boolean : _OracleBoolean, +} + +ischema_names = { + 'VARCHAR2' : VARCHAR, + 'NVARCHAR2' : NVARCHAR, + 'CHAR' : CHAR, + 'DATE' : DATE, + 'DATETIME' : DATETIME, + 'NUMBER' : NUMBER, + 'BLOB' : BLOB, + 'BFILE' : BFILE, + 'CLOB' : CLOB, + 'NCLOB' : NCLOB, + 'TIMESTAMP' : TIMESTAMP, + 'RAW' : RAW, 
+ 'FLOAT' : FLOAT, + 'DOUBLE PRECISION' : DOUBLE_PRECISION, + 'LONG' : LONG, +} + + +class OracleTypeCompiler(compiler.GenericTypeCompiler): + # Note: + # Oracle DATE == DATETIME + # Oracle does not allow milliseconds in DATE + # Oracle does not support TIME columns + + def visit_datetime(self, type_): + return self.visit_DATE(type_) + + def visit_float(self, type_): + if type_.precision is None: + return "NUMERIC" + else: + return "NUMERIC(%(precision)s, %(scale)s)" % {'precision': type_.precision, 'scale' : 2} + + def visit_unicode(self, type_): + return self.visit_NVARCHAR(type_) + + def visit_VARCHAR(self, type_): + return "VARCHAR(%(length)s)" % {'length' : type_.length} + + def visit_NVARCHAR(self, type_): + return "NVARCHAR2(%(length)s)" % {'length' : type_.length} + + def visit_text(self, type_): + return self.visit_CLOB(type_) + + def visit_unicode_text(self, type_): + return self.visit_NCLOB(type_) + + def visit_binary(self, type_): + return self.visit_BLOB(type_) + + def visit_boolean(self, type_): + return self.visit_SMALLINT(type_) + + def visit_RAW(self, type_): + return "RAW(%(length)s)" % {'length' : type_.length} + +class OracleCompiler(compiler.SQLCompiler): + """Oracle compiler modifies the lexical structure of Select + statements to work under non-ANSI configured Oracle databases, if + the use_ansi flag is False. + """ + + def __init__(self, *args, **kwargs): + super(OracleCompiler, self).__init__(*args, **kwargs) + self.__wheres = {} + self._quoted_bind_names = {} + + def visit_mod(self, binary, **kw): + return "mod(%s, %s)" % (self.process(binary.left), self.process(binary.right)) + + def visit_now_func(self, fn, **kw): + return "CURRENT_TIMESTAMP" + + def visit_char_length_func(self, fn, **kw): + return "LENGTH" + self.function_argspec(fn, **kw) + + def visit_match_op(self, binary, **kw): + return "CONTAINS (%s, %s)" % (self.process(binary.left), self.process(binary.right)) + + def function_argspec(self, fn, **kw): + if len(fn.clauses) > 0: + return compiler.SQLCompiler.function_argspec(self, fn, **kw) + else: + return "" + + def default_from(self): + """Called when a ``SELECT`` statement has no froms, and no ``FROM`` clause is to be appended. + + The Oracle compiler tacks a "FROM DUAL" to the statement. + """ + + return " FROM DUAL" + + def visit_join(self, join, **kwargs): + if self.dialect.use_ansi: + return compiler.SQLCompiler.visit_join(self, join, **kwargs) + else: + return self.process(join.left, asfrom=True) + ", " + self.process(join.right, asfrom=True) + + def _get_nonansi_join_whereclause(self, froms): + clauses = [] + + def visit_join(join): + if join.isouter: + def visit_binary(binary): + if binary.operator == sql_operators.eq: + if binary.left.table is join.right: + binary.left = _OuterJoinColumn(binary.left) + elif binary.right.table is join.right: + binary.right = _OuterJoinColumn(binary.right) + clauses.append(visitors.cloned_traverse(join.onclause, {}, {'binary':visit_binary})) + else: + clauses.append(join.onclause) + + for f in froms: + visitors.traverse(f, {}, {'join':visit_join}) + return sql.and_(*clauses) + + def visit_outer_join_column(self, vc): + return self.process(vc.column) + "(+)" + + def visit_sequence(self, seq): + return self.dialect.identifier_preparer.format_sequence(seq) + ".nextval" + + def visit_alias(self, alias, asfrom=False, **kwargs): + """Oracle doesn't like ``FROM table AS alias``. 
The ``AS`` keyword is therefore omitted when rendering aliases."""
+
+        if asfrom:
+            alias_name = isinstance(alias.name, expression._generated_label) and \
+                            self._truncated_identifier("alias", alias.name) or alias.name
+
+            return self.process(alias.original, asfrom=asfrom, **kwargs) + " " + self.preparer.format_alias(alias, alias_name)
+        else:
+            return self.process(alias.original, **kwargs)
+
+    def returning_clause(self, stmt, returning_cols):
+
+        def create_out_param(col, i):
+            bindparam = sql.outparam("ret_%d" % i, type_=col.type)
+            self.binds[bindparam.key] = bindparam
+            return self.bindparam_string(self._truncate_bindparam(bindparam))
+
+        columnlist = list(expression._select_iterables(returning_cols))
+
+        # within_columns_clause=False so that labels (foo AS bar) don't render
+        columns = [self.process(c, within_columns_clause=False) for c in columnlist]
+
+        binds = [create_out_param(c, i) for i, c in enumerate(columnlist)]
+
+        return 'RETURNING ' + ', '.join(columns) + " INTO " + ", ".join(binds)
+
+    def _TODO_visit_compound_select(self, select):
+        """Need to determine how to get ``LIMIT``/``OFFSET`` into a ``UNION`` for Oracle."""
+        pass
+
+    def visit_select(self, select, **kwargs):
+        """Look for ``LIMIT`` and ``OFFSET`` in a select statement; if present,
+        wrap the statement in a subquery with a ``rownum`` criterion.
+        """
+
+        if not getattr(select, '_oracle_visit', None):
+            if not self.dialect.use_ansi:
+                if self.stack and 'from' in self.stack[-1]:
+                    existingfroms = self.stack[-1]['from']
+                else:
+                    existingfroms = None
+
+                froms = select._get_display_froms(existingfroms)
+                whereclause = self._get_nonansi_join_whereclause(froms)
+                if whereclause:
+                    select = select.where(whereclause)
+                    select._oracle_visit = True
+
+            if select._limit is not None or select._offset is not None:
+                # See http://www.oracle.com/technology/oramag/oracle/06-sep/o56asktom.html
+                #
+                # Generalized form of an Oracle pagination query:
+                #   select ... from (
+                #     select /*+ FIRST_ROWS(N) */ ...., rownum as ora_rn from (
+                #         select distinct ... where ... order by ...
+                #     ) where ROWNUM <= :limit+:offset
+                #   ) where ora_rn > :offset
+                # Outer select and "ROWNUM as ora_rn" can be dropped if limit=0
+
+                # TODO: use annotations instead of clone + attr set ?
+                select = select._generate()
+                select._oracle_visit = True
+
+                # Wrap the middle select and add the hint
+                limitselect = sql.select([c for c in select.c])
+                if select._limit and self.dialect.optimize_limits:
+                    limitselect = limitselect.prefix_with("/*+ FIRST_ROWS(%d) */" % select._limit)
+
+                limitselect._oracle_visit = True
+                limitselect._is_wrapper = True
+
+                # If needed, add the limiting clause
+                if select._limit is not None:
+                    max_row = select._limit
+                    if select._offset is not None:
+                        max_row += select._offset
+                    limitselect.append_whereclause(
+                            sql.literal_column("ROWNUM")<=max_row)
+
+                # If needed, add the ora_rn, and wrap again with offset.
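+                # For example, limit=10 with offset=5 yields
+                # "WHERE ROWNUM <= 15" on the inner wrapper (limit + offset)
+                # and "WHERE ora_rn > 5" on the outer wrapper built below.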
+                if select._offset is None:
+                    select = limitselect
+                else:
+                    limitselect = limitselect.column(
+                            sql.literal_column("ROWNUM").label("ora_rn"))
+                    limitselect._oracle_visit = True
+                    limitselect._is_wrapper = True
+
+                    offsetselect = sql.select(
+                            [c for c in limitselect.c if c.key!='ora_rn'])
+                    offsetselect._oracle_visit = True
+                    offsetselect._is_wrapper = True
+
+                    offsetselect.append_whereclause(
+                            sql.literal_column("ora_rn")>select._offset)
+
+                    select = offsetselect
+
+        kwargs['iswrapper'] = getattr(select, '_is_wrapper', False)
+        return compiler.SQLCompiler.visit_select(self, select, **kwargs)
+
+    def limit_clause(self, select):
+        return ""
+
+    def for_update_clause(self, select):
+        if select.for_update == "nowait":
+            return " FOR UPDATE NOWAIT"
+        else:
+            return super(OracleCompiler, self).for_update_clause(select)
+
+class OracleDDLCompiler(compiler.DDLCompiler):
+
+    def visit_create_sequence(self, create):
+        return "CREATE SEQUENCE %s" % self.preparer.format_sequence(create.element)
+
+    def visit_drop_sequence(self, drop):
+        return "DROP SEQUENCE %s" % self.preparer.format_sequence(drop.element)
+
+    def define_constraint_cascades(self, constraint):
+        text = ""
+        if constraint.ondelete is not None:
+            text += " ON DELETE %s" % constraint.ondelete
+
+        # oracle has no ON UPDATE CASCADE -
+        # it's only available via triggers: http://asktom.oracle.com/tkyte/update_cascade/index.html
+        if constraint.onupdate is not None:
+            util.warn(
+                "Oracle does not contain native UPDATE CASCADE "
+                "functionality - onupdates will not be rendered for foreign keys. "
+                "Consider using deferrable=True, initially='deferred' or triggers.")
+
+        return text
+
+class OracleDefaultRunner(base.DefaultRunner):
+    def visit_sequence(self, seq):
+        return self.execute_string("SELECT " +
+                    self.dialect.identifier_preparer.format_sequence(seq) +
+                    ".nextval FROM DUAL", ())
+
+class OracleIdentifierPreparer(compiler.IdentifierPreparer):
+
+    reserved_words = set([x.lower() for x in RESERVED_WORDS])
+    illegal_initial_characters = re.compile(r'[0-9_$]')
+
+    def _bindparam_requires_quotes(self, value):
+        """Return True if the given identifier requires quoting."""
+        lc_value = value.lower()
+        return (lc_value in self.reserved_words
+                or self.illegal_initial_characters.match(value[0])
+                or not self.legal_characters.match(unicode(value))
+                )
+
+    def format_savepoint(self, savepoint):
+        name = re.sub(r'^_+', '', savepoint.ident)
+        return super(OracleIdentifierPreparer, self).format_savepoint(savepoint, name)
+
+class OracleDialect(default.DefaultDialect):
+    name = 'oracle'
+    supports_alter = True
+    supports_unicode_statements = False
+    supports_unicode_binds = False
+    max_identifier_length = 30
+    supports_sane_rowcount = True
+    supports_sane_multi_rowcount = False
+
+    supports_sequences = True
+    sequences_optional = False
+    postfetch_lastrowid = False
+
+    default_paramstyle = 'named'
+    colspecs = colspecs
+    ischema_names = ischema_names
+    requires_name_normalize = True
+
+    supports_default_values = False
+    supports_empty_insert = False
+
+    statement_compiler = OracleCompiler
+    ddl_compiler = OracleDDLCompiler
+    type_compiler = OracleTypeCompiler
+    preparer = OracleIdentifierPreparer
+    defaultrunner = OracleDefaultRunner
+
+    reflection_options = ('oracle_resolve_synonyms', )
+
+
+    def __init__(self,
+                use_ansi=True,
+                optimize_limits=False,
+                **kwargs):
+        default.DefaultDialect.__init__(self, **kwargs)
+        self.use_ansi = use_ansi
+        self.optimize_limits = optimize_limits
+
+# TODO: implement server_version_info for oracle
+# def
initialize(self, connection): +# super(OracleDialect, self).initialize(connection) +# self.implicit_returning = self.server_version_info > (10, ) and \ +# self.__dict__.get('implicit_returning', True) + + def do_release_savepoint(self, connection, name): + # Oracle does not support RELEASE SAVEPOINT + pass + + def has_table(self, connection, table_name, schema=None): + if not schema: + schema = self.get_default_schema_name(connection) + cursor = connection.execute( + sql.text("SELECT table_name FROM all_tables " + "WHERE table_name = :name AND owner = :schema_name"), + name=self.denormalize_name(table_name), schema_name=self.denormalize_name(schema)) + return cursor.fetchone() is not None + + def has_sequence(self, connection, sequence_name, schema=None): + if not schema: + schema = self.get_default_schema_name(connection) + cursor = connection.execute( + sql.text("SELECT sequence_name FROM all_sequences " + "WHERE sequence_name = :name AND sequence_owner = :schema_name"), + name=self.denormalize_name(sequence_name), schema_name=self.denormalize_name(schema)) + return cursor.fetchone() is not None + + def normalize_name(self, name): + if name is None: + return None + elif (name.upper() == name and + not self.identifier_preparer._requires_quotes(name.lower().decode(self.encoding))): + return name.lower().decode(self.encoding) + else: + return name.decode(self.encoding) + + def denormalize_name(self, name): + if name is None: + return None + elif name.lower() == name and not self.identifier_preparer._requires_quotes(name.lower()): + return name.upper().encode(self.encoding) + else: + return name.encode(self.encoding) + + def get_default_schema_name(self, connection): + return self.normalize_name(connection.execute('SELECT USER FROM DUAL').scalar()) + + def table_names(self, connection, schema): + # note that table_names() isnt loading DBLINKed or synonym'ed tables + if schema is None: + cursor = connection.execute( + "SELECT table_name FROM all_tables " + "WHERE nvl(tablespace_name, 'no tablespace') NOT IN ('SYSTEM', 'SYSAUX')") + else: + s = sql.text( + "SELECT table_name FROM all_tables " + "WHERE nvl(tablespace_name, 'no tablespace') NOT IN ('SYSTEM', 'SYSAUX') " + "AND OWNER = :owner") + cursor = connection.execute(s, owner=self.denormalize_name(schema)) + return [self.normalize_name(row[0]) for row in cursor] + + def _resolve_synonym(self, connection, desired_owner=None, desired_synonym=None, desired_table=None): + """search for a local synonym matching the given desired owner/name. + + if desired_owner is None, attempts to locate a distinct owner. + + returns the actual name, owner, dblink name, and synonym name if found. 
+ """ + + q = "SELECT owner, table_owner, table_name, db_link, synonym_name FROM all_synonyms WHERE " + clauses = [] + params = {} + if desired_synonym: + clauses.append("synonym_name = :synonym_name") + params['synonym_name'] = desired_synonym + if desired_owner: + clauses.append("table_owner = :desired_owner") + params['desired_owner'] = desired_owner + if desired_table: + clauses.append("table_name = :tname") + params['tname'] = desired_table + + q += " AND ".join(clauses) + + result = connection.execute(sql.text(q), **params) + if desired_owner: + row = result.fetchone() + if row: + return row['table_name'], row['table_owner'], row['db_link'], row['synonym_name'] + else: + return None, None, None, None + else: + rows = result.fetchall() + if len(rows) > 1: + raise AssertionError("There are multiple tables visible to the schema, you must specify owner") + elif len(rows) == 1: + row = rows[0] + return row['table_name'], row['table_owner'], row['db_link'], row['synonym_name'] + else: + return None, None, None, None + + @reflection.cache + def _prepare_reflection_args(self, connection, table_name, schema=None, + resolve_synonyms=False, dblink='', **kw): + + if resolve_synonyms: + actual_name, owner, dblink, synonym = self._resolve_synonym(connection, desired_owner=self.denormalize_name(schema), desired_synonym=self.denormalize_name(table_name)) + else: + actual_name, owner, dblink, synonym = None, None, None, None + if not actual_name: + actual_name = self.denormalize_name(table_name) + if not dblink: + dblink = '' + if not owner: + owner = self.denormalize_name(schema or self.get_default_schema_name(connection)) + return (actual_name, owner, dblink, synonym) + + @reflection.cache + def get_schema_names(self, connection, **kw): + s = "SELECT username FROM all_users ORDER BY username" + cursor = connection.execute(s,) + return [self.normalize_name(row[0]) for row in cursor] + + @reflection.cache + def get_table_names(self, connection, schema=None, **kw): + schema = self.denormalize_name(schema or self.get_default_schema_name(connection)) + return self.table_names(connection, schema) + + @reflection.cache + def get_view_names(self, connection, schema=None, **kw): + schema = self.denormalize_name(schema or self.get_default_schema_name(connection)) + s = sql.text("SELECT view_name FROM all_views WHERE owner = :owner") + cursor = connection.execute(s, owner=self.denormalize_name(schema)) + return [self.normalize_name(row[0]) for row in cursor] + + @reflection.cache + def get_columns(self, connection, table_name, schema=None, **kw): + """ + + kw arguments can be: + + oracle_resolve_synonyms + + dblink + + """ + + resolve_synonyms = kw.get('oracle_resolve_synonyms', False) + dblink = kw.get('dblink', '') + info_cache = kw.get('info_cache') + + (table_name, schema, dblink, synonym) = \ + self._prepare_reflection_args(connection, table_name, schema, + resolve_synonyms, dblink, + info_cache=info_cache) + columns = [] + c = connection.execute(sql.text( + "SELECT column_name, data_type, data_length, data_precision, data_scale, " + "nullable, data_default FROM ALL_TAB_COLUMNS%(dblink)s " + "WHERE table_name = :table_name AND owner = :owner" % {'dblink': dblink}), + table_name=table_name, owner=schema) + + for row in c: + (colname, coltype, length, precision, scale, nullable, default) = \ + (self.normalize_name(row[0]), row[1], row[2], row[3], row[4], row[5]=='Y', row[6]) + + # INTEGER if the scale is 0 and precision is null + # NUMBER if the scale and precision are both null + # NUMBER(9,2) if the 
precision is 9 and the scale is 2 + # NUMBER(3) if the precision is 3 and scale is 0 + #length is ignored except for CHAR and VARCHAR2 + if coltype == 'NUMBER' : + if precision is None and scale is None: + coltype = sqltypes.NUMERIC + elif precision is None and scale == 0: + coltype = sqltypes.INTEGER + else : + coltype = sqltypes.NUMERIC(precision, scale) + elif coltype=='CHAR' or coltype=='VARCHAR2': + coltype = self.ischema_names.get(coltype)(length) + else: + coltype = re.sub(r'\(\d+\)', '', coltype) + try: + coltype = self.ischema_names[coltype] + except KeyError: + util.warn("Did not recognize type '%s' of column '%s'" % + (coltype, colname)) + coltype = sqltypes.NULLTYPE + + cdict = { + 'name': colname, + 'type': coltype, + 'nullable': nullable, + 'default': default, + } + columns.append(cdict) + return columns + + @reflection.cache + def get_indexes(self, connection, table_name, schema=None, + resolve_synonyms=False, dblink='', **kw): + + + info_cache = kw.get('info_cache') + (table_name, schema, dblink, synonym) = \ + self._prepare_reflection_args(connection, table_name, schema, + resolve_synonyms, dblink, + info_cache=info_cache) + indexes = [] + q = sql.text(""" + SELECT a.index_name, a.column_name, b.uniqueness + FROM ALL_IND_COLUMNS%(dblink)s a + INNER JOIN ALL_INDEXES%(dblink)s b + ON a.index_name = b.index_name + AND a.table_owner = b.table_owner + AND a.table_name = b.table_name + WHERE a.table_name = :table_name + AND a.table_owner = :schema + ORDER BY a.index_name, a.column_position""" % {'dblink': dblink}) + rp = connection.execute(q, table_name=self.denormalize_name(table_name), + schema=self.denormalize_name(schema)) + indexes = [] + last_index_name = None + pkeys = self.get_primary_keys(connection, table_name, schema, + resolve_synonyms=resolve_synonyms, + dblink=dblink, + info_cache=kw.get('info_cache')) + uniqueness = dict(NONUNIQUE=False, UNIQUE=True) + for rset in rp: + # don't include the primary key columns + if rset.column_name in [s.upper() for s in pkeys]: + continue + if rset.index_name != last_index_name: + index = dict(name=self.normalize_name(rset.index_name), column_names=[]) + indexes.append(index) + index['unique'] = uniqueness.get(rset.uniqueness, False) + index['column_names'].append(self.normalize_name(rset.column_name)) + last_index_name = rset.index_name + return indexes + + @reflection.cache + def _get_constraint_data(self, connection, table_name, schema=None, + dblink='', **kw): + + rp = connection.execute( + sql.text("""SELECT + ac.constraint_name, + ac.constraint_type, + loc.column_name AS local_column, + rem.table_name AS remote_table, + rem.column_name AS remote_column, + rem.owner AS remote_owner, + loc.position as loc_pos, + rem.position as rem_pos + FROM all_constraints%(dblink)s ac, + all_cons_columns%(dblink)s loc, + all_cons_columns%(dblink)s rem + WHERE ac.table_name = :table_name + AND ac.constraint_type IN ('R','P') + AND ac.owner = :owner + AND ac.owner = loc.owner + AND ac.constraint_name = loc.constraint_name + AND ac.r_owner = rem.owner(+) + AND ac.r_constraint_name = rem.constraint_name(+) + AND (rem.position IS NULL or loc.position=rem.position) + ORDER BY ac.constraint_name, loc.position""" % {'dblink': dblink}), + table_name=table_name, owner=schema) + constraint_data = rp.fetchall() + return constraint_data + + @reflection.cache + def get_primary_keys(self, connection, table_name, schema=None, **kw): + """ + + kw arguments can be: + + oracle_resolve_synonyms + + dblink + + """ + + resolve_synonyms = 
kw.get('oracle_resolve_synonyms', False) + dblink = kw.get('dblink', '') + info_cache = kw.get('info_cache') + + (table_name, schema, dblink, synonym) = \ + self._prepare_reflection_args(connection, table_name, schema, + resolve_synonyms, dblink, + info_cache=info_cache) + pkeys = [] + constraint_data = self._get_constraint_data(connection, table_name, + schema, dblink, + info_cache=kw.get('info_cache')) + + for row in constraint_data: + #print "ROW:" , row + (cons_name, cons_type, local_column, remote_table, remote_column, remote_owner) = \ + row[0:2] + tuple([self.normalize_name(x) for x in row[2:6]]) + if cons_type == 'P': + pkeys.append(local_column) + return pkeys + + @reflection.cache + def get_foreign_keys(self, connection, table_name, schema=None, **kw): + """ + + kw arguments can be: + + oracle_resolve_synonyms + + dblink + + """ + + requested_schema = schema # to check later on + resolve_synonyms = kw.get('oracle_resolve_synonyms', False) + dblink = kw.get('dblink', '') + info_cache = kw.get('info_cache') + + (table_name, schema, dblink, synonym) = \ + self._prepare_reflection_args(connection, table_name, schema, + resolve_synonyms, dblink, + info_cache=info_cache) + + constraint_data = self._get_constraint_data(connection, table_name, + schema, dblink, + info_cache=kw.get('info_cache')) + + def fkey_rec(): + return { + 'name' : None, + 'constrained_columns' : [], + 'referred_schema' : None, + 'referred_table' : None, + 'referred_columns' : [] + } + + fkeys = util.defaultdict(fkey_rec) + + for row in constraint_data: + (cons_name, cons_type, local_column, remote_table, remote_column, remote_owner) = \ + row[0:2] + tuple([self.normalize_name(x) for x in row[2:6]]) + + if cons_type == 'R': + if remote_table is None: + # ticket 363 + util.warn( + ("Got 'None' querying 'table_name' from " + "all_cons_columns%(dblink)s - does the user have " + "proper rights to the table?") % {'dblink':dblink}) + continue + + rec = fkeys[cons_name] + rec['name'] = cons_name + local_cols, remote_cols = rec['constrained_columns'], rec['referred_columns'] + + if not rec['referred_table']: + if resolve_synonyms: + ref_remote_name, ref_remote_owner, ref_dblink, ref_synonym = \ + self._resolve_synonym( + connection, + desired_owner=self.denormalize_name(remote_owner), + desired_table=self.denormalize_name(remote_table) + ) + if ref_synonym: + remote_table = self.normalize_name(ref_synonym) + remote_owner = self.normalize_name(ref_remote_owner) + + rec['referred_table'] = remote_table + + if requested_schema is not None or self.denormalize_name(remote_owner) != schema: + rec['referred_schema'] = remote_owner + + local_cols.append(local_column) + remote_cols.append(remote_column) + + return fkeys.values() + + @reflection.cache + def get_view_definition(self, connection, view_name, schema=None, + resolve_synonyms=False, dblink='', **kw): + info_cache = kw.get('info_cache') + (view_name, schema, dblink, synonym) = \ + self._prepare_reflection_args(connection, view_name, schema, + resolve_synonyms, dblink, + info_cache=info_cache) + s = sql.text(""" + SELECT text FROM all_views + WHERE owner = :schema + AND view_name = :view_name + """) + rp = connection.execute(s, + view_name=view_name, schema=schema).scalar() + if rp: + return rp.decode(self.encoding) + else: + return None + + + +class _OuterJoinColumn(sql.ClauseElement): + __visit_name__ = 'outer_join_column' + + def __init__(self, column): + self.column = column + + + diff --git a/lib/sqlalchemy/dialects/oracle/cx_oracle.py 
b/lib/sqlalchemy/dialects/oracle/cx_oracle.py
new file mode 100644
index 000000000..d8a0c445a
--- /dev/null
+++ b/lib/sqlalchemy/dialects/oracle/cx_oracle.py
@@ -0,0 +1,371 @@
+"""Support for the Oracle database via the cx_oracle driver.
+
+Driver
+------
+
+The Oracle dialect uses the cx_oracle driver, available at
+http://cx-oracle.sourceforge.net/ . The dialect has several behaviors
+which are specifically tailored towards compatibility with this module.
+
+Connecting
+----------
+
+Connecting with create_engine() uses the standard URL approach of
+``oracle://user:pass@host:port/dbname[?key=value&key=value...]``. If dbname is present, the
+host, port, and dbname tokens are converted to a TNS name using the cx_oracle
+:func:`makedsn()` function. Otherwise, the host token is taken directly as a TNS name.
+
+Additional arguments which may be specified either as query string arguments on the
+URL, or as keyword arguments to :func:`~sqlalchemy.create_engine()` are:
+
+* *allow_twophase* - enable two-phase transactions. Defaults to ``True``.
+
+* *auto_convert_lobs* - defaults to ``True``; see the section on LOB objects.
+
+* *auto_setinputsizes* - the cx_oracle.setinputsizes() call is issued for all bind parameters.
+  This is required for LOB datatypes but can be disabled to reduce overhead. Defaults
+  to ``True``.
+
+* *mode* - This is given the string value of SYSDBA or SYSOPER, or alternatively an
+  integer value. This value is only available as a URL query string argument.
+
+* *threaded* - enable multithreaded access to cx_oracle connections. Defaults
+  to ``True``. Note that this is the opposite default of cx_oracle itself.
+
+
+LOB Objects
+-----------
+
+cx_oracle presents some challenges when fetching LOB objects. A LOB object in a result set
+is presented by cx_oracle as a cx_oracle.LOB object which has a read() method. By default,
+SQLAlchemy converts these LOB objects into Python strings. This is done for two reasons.
+First, the LOB object requires an active cursor association: if enough rows are fetched
+at once that cx_oracle must go back to the database for a new batch, the LOB objects in
+the already-fetched rows become unreadable and raise an error. SQLA therefore "pre-reads"
+all LOBs so that their data is fetched before further rows are read.
+The size of a "batch of rows" is controlled by the cursor.arraysize value, which SQLAlchemy
+defaults to 50 (cx_oracle normally defaults this to one).
+
+Secondly, the LOB object is not a standard DBAPI return value, so SQLAlchemy seeks to
+"normalize" the results to look more like those of other DBAPIs.
+
+The conversion of LOB objects by this dialect is unique in SQLAlchemy in that it takes place
+for all statement executions, even plain string-based statements for which SQLA has no awareness
+of result typing. This is so that calls like fetchmany() and fetchall() can work in all cases
+without raising cursor errors. The conversion of LOBs in all cases, as well as the "prefetch"
+of LOB objects, can be disabled using auto_convert_lobs=False.
+
+Two Phase Transaction Support
+-----------------------------
+
+Two-phase transactions are implemented using XA transactions. They have been reported
+to work, but this support should be regarded as an experimental feature.
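+
+A minimal usage sketch (assuming ``engine`` is an Oracle engine created with the
+default ``allow_twophase=True``, and ``some_table`` is an existing Table)::
+
+    conn = engine.connect()
+    xa = conn.begin_twophase()
+    conn.execute(some_table.insert(), dict(id=1))
+    xa.prepare()
+    xa.commit()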
+ +""" + +from sqlalchemy.dialects.oracle.base import OracleCompiler, OracleDialect, RESERVED_WORDS +from sqlalchemy.dialects.oracle import base as oracle +from sqlalchemy.engine.default import DefaultExecutionContext +from sqlalchemy.engine import base +from sqlalchemy import types as sqltypes, util +import datetime + +class _OracleDate(sqltypes.Date): + def bind_processor(self, dialect): + return None + + def result_processor(self, dialect): + def process(value): + if not isinstance(value, datetime.datetime): + return value + else: + return value.date() + return process + +class _OracleDateTime(sqltypes.DateTime): + def result_processor(self, dialect): + def process(value): + if value is None or isinstance(value, datetime.datetime): + return value + else: + # convert cx_oracle datetime object returned pre-python 2.4 + return datetime.datetime(value.year, value.month, + value.day,value.hour, value.minute, value.second) + return process + +# Note: +# Oracle DATE == DATETIME +# Oracle does not allow milliseconds in DATE +# Oracle does not support TIME columns + +# only if cx_oracle contains TIMESTAMP +class _OracleTimestamp(sqltypes.TIMESTAMP): + def result_processor(self, dialect): + def process(value): + if value is None or isinstance(value, datetime.datetime): + return value + else: + # convert cx_oracle datetime object returned pre-python 2.4 + return datetime.datetime(value.year, value.month, + value.day,value.hour, value.minute, value.second) + return process + +class _LOBMixin(object): + def result_processor(self, dialect): + super_process = super(_LOBMixin, self).result_processor(dialect) + if not dialect.auto_convert_lobs: + return super_process + lob = dialect.dbapi.LOB + def process(value): + if isinstance(value, lob): + if super_process: + return super_process(value.read()) + else: + return value.read() + else: + if super_process: + return super_process(value) + else: + return value + return process + +class _OracleText(_LOBMixin, sqltypes.Text): + def get_dbapi_type(self, dbapi): + return dbapi.CLOB + +class _OracleUnicodeText(_LOBMixin, sqltypes.UnicodeText): + def get_dbapi_type(self, dbapi): + return dbapi.NCLOB + + +class _OracleBinary(_LOBMixin, sqltypes.Binary): + def get_dbapi_type(self, dbapi): + return dbapi.BLOB + + def bind_processor(self, dialect): + return None + + +class _OracleRaw(_LOBMixin, oracle.RAW): + pass + + +colspecs = { + sqltypes.DateTime : _OracleDateTime, + sqltypes.Date : _OracleDate, + sqltypes.Binary : _OracleBinary, + sqltypes.Boolean : oracle._OracleBoolean, + sqltypes.Text : _OracleText, + sqltypes.UnicodeText : _OracleUnicodeText, + sqltypes.TIMESTAMP : _OracleTimestamp, + oracle.RAW: _OracleRaw, +} + +class Oracle_cx_oracleCompiler(OracleCompiler): + def bindparam_string(self, name): + if self.preparer._bindparam_requires_quotes(name): + quoted_name = '"%s"' % name + self._quoted_bind_names[name] = quoted_name + return OracleCompiler.bindparam_string(self, quoted_name) + else: + return OracleCompiler.bindparam_string(self, name) + +class Oracle_cx_oracleExecutionContext(DefaultExecutionContext): + def pre_exec(self): + quoted_bind_names = getattr(self.compiled, '_quoted_bind_names', {}) + if quoted_bind_names: + for param in self.parameters: + for fromname, toname in self.compiled._quoted_bind_names.iteritems(): + param[toname.encode(self.dialect.encoding)] = param[fromname] + del param[fromname] + + if self.dialect.auto_setinputsizes: + self.set_input_sizes(quoted_bind_names, exclude_types=(self.dialect.dbapi.STRING,)) + + if 
len(self.compiled_parameters) == 1: + for key in self.compiled.binds: + bindparam = self.compiled.binds[key] + name = self.compiled.bind_names[bindparam] + value = self.compiled_parameters[0][name] + if bindparam.isoutparam: + dbtype = bindparam.type.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi) + if not hasattr(self, 'out_parameters'): + self.out_parameters = {} + self.out_parameters[name] = self.cursor.var(dbtype) + self.parameters[0][quoted_bind_names.get(name, name)] = self.out_parameters[name] + + + def create_cursor(self): + c = self._connection.connection.cursor() + if self.dialect.arraysize: + c.cursor.arraysize = self.dialect.arraysize + return c + + def get_result_proxy(self): + if hasattr(self, 'out_parameters'): + if self.compiled_parameters is not None and len(self.compiled_parameters) == 1: + for bind, name in self.compiled.bind_names.iteritems(): + if name in self.out_parameters: + type = bind.type + result_processor = type.dialect_impl(self.dialect).result_processor(self.dialect) + if result_processor is not None: + self.out_parameters[name] = result_processor(self.out_parameters[name].getvalue()) + else: + self.out_parameters[name] = self.out_parameters[name].getvalue() + else: + for k in self.out_parameters: + self.out_parameters[k] = self.out_parameters[k].getvalue() + + if self.cursor.description is not None: + for column in self.cursor.description: + type_code = column[1] + if type_code in self.dialect.ORACLE_BINARY_TYPES: + return base.BufferedColumnResultProxy(self) + + if hasattr(self, 'out_parameters') and \ + self.compiled.returning: + + return ReturningResultProxy(self) + else: + return base.ResultProxy(self) + +class ReturningResultProxy(base.FullyBufferedResultProxy): + """Result proxy which stuffs the _returning clause + outparams into the fetch.""" + + def _cursor_description(self): + returning = self.context.compiled.returning + + ret = [] + for c in returning: + if hasattr(c, 'key'): + ret.append((c.key, c.type)) + else: + ret.append((c.anon_label, c.type)) + return ret + + def _buffer_rows(self): + returning = self.context.compiled.returning + return [tuple(self.context.out_parameters["ret_%d" % i] for i, c in enumerate(returning))] + +class Oracle_cx_oracle(OracleDialect): + execution_ctx_cls = Oracle_cx_oracleExecutionContext + statement_compiler = Oracle_cx_oracleCompiler + driver = "cx_oracle" + colspecs = colspecs + + def __init__(self, + auto_setinputsizes=True, + auto_convert_lobs=True, + threaded=True, + allow_twophase=True, + arraysize=50, **kwargs): + OracleDialect.__init__(self, **kwargs) + self.threaded = threaded + self.arraysize = arraysize + self.allow_twophase = allow_twophase + self.supports_timestamp = self.dbapi is None or hasattr(self.dbapi, 'TIMESTAMP' ) + self.auto_setinputsizes = auto_setinputsizes + self.auto_convert_lobs = auto_convert_lobs + + def vers(num): + return tuple([int(x) for x in num.split('.')]) + + if hasattr(self.dbapi, 'version'): + cx_oracle_ver = vers(self.dbapi.version) + self.supports_unicode_binds = cx_oracle_ver >= (5, 0) + + if self.dbapi is None or not self.auto_convert_lobs or not 'CLOB' in self.dbapi.__dict__: + self.dbapi_type_map = {} + self.ORACLE_BINARY_TYPES = [] + else: + # only use this for LOB objects. using it for strings, dates + # etc. leads to a little too much magic, reflection doesn't know if it should + # expect encoded strings or unicodes, etc. 
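+            # (This mapping is what lets LOB conversion apply even to plain
+            # string-based statements, per the module docstring;
+            # ORACLE_BINARY_TYPES below is consulted in get_result_proxy()
+            # to select a BufferedColumnResultProxy for LOB result columns.)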
+ self.dbapi_type_map = { + self.dbapi.CLOB: oracle.CLOB(), + self.dbapi.NCLOB:oracle.NCLOB(), + self.dbapi.BLOB: oracle.BLOB(), + self.dbapi.BINARY: oracle.RAW(), + } + self.ORACLE_BINARY_TYPES = [getattr(self.dbapi, k) for k in ["BFILE", "CLOB", "NCLOB", "BLOB"] if hasattr(self.dbapi, k)] + + @classmethod + def dbapi(cls): + import cx_Oracle + return cx_Oracle + + def create_connect_args(self, url): + dialect_opts = dict(url.query) + for opt in ('use_ansi', 'auto_setinputsizes', 'auto_convert_lobs', + 'threaded', 'allow_twophase'): + if opt in dialect_opts: + util.coerce_kw_type(dialect_opts, opt, bool) + setattr(self, opt, dialect_opts[opt]) + + if url.database: + # if we have a database, then we have a remote host + port = url.port + if port: + port = int(port) + else: + port = 1521 + dsn = self.dbapi.makedsn(url.host, port, url.database) + else: + # we have a local tnsname + dsn = url.host + + opts = dict( + user=url.username, + password=url.password, + dsn=dsn, + threaded=self.threaded, + twophase=self.allow_twophase, + ) + if 'mode' in url.query: + opts['mode'] = url.query['mode'] + if isinstance(opts['mode'], basestring): + mode = opts['mode'].upper() + if mode == 'SYSDBA': + opts['mode'] = self.dbapi.SYSDBA + elif mode == 'SYSOPER': + opts['mode'] = self.dbapi.SYSOPER + else: + util.coerce_kw_type(opts, 'mode', int) + # Can't set 'handle' or 'pool' via URL query args, use connect_args + + return ([], opts) + + def _get_server_version_info(self, connection): + return tuple(int(x) for x in connection.connection.version.split('.')) + + def is_disconnect(self, e): + if isinstance(e, self.dbapi.InterfaceError): + return "not connected" in str(e) + else: + return "ORA-03114" in str(e) or "ORA-03113" in str(e) + + def create_xid(self): + """create a two-phase transaction ID. + + this id will be passed to do_begin_twophase(), do_rollback_twophase(), + do_commit_twophase(). 
its format is unspecified.""" + + id = random.randint(0, 2 ** 128) + return (0x1234, "%032x" % id, "%032x" % 9) + + def do_begin_twophase(self, connection, xid): + connection.connection.begin(*xid) + + def do_prepare_twophase(self, connection, xid): + connection.connection.prepare() + + def do_rollback_twophase(self, connection, xid, is_prepared=True, recover=False): + self.do_rollback(connection.connection) + + def do_commit_twophase(self, connection, xid, is_prepared=True, recover=False): + self.do_commit(connection.connection) + + def do_recover_twophase(self, connection): + pass + +dialect = Oracle_cx_oracle diff --git a/lib/sqlalchemy/dialects/oracle/zxjdbc.py b/lib/sqlalchemy/dialects/oracle/zxjdbc.py new file mode 100644 index 000000000..a0ad088b2 --- /dev/null +++ b/lib/sqlalchemy/dialects/oracle/zxjdbc.py @@ -0,0 +1,24 @@ +"""Support for the Oracle database via the zxjdbc JDBC connector.""" +import re + +from sqlalchemy.connectors.zxJDBC import ZxJDBCConnector +from sqlalchemy.dialects.oracle.base import OracleDialect + +class Oracle_jdbc(ZxJDBCConnector, OracleDialect): + + jdbc_db_name = 'oracle' + jdbc_driver_name = 'oracle.jdbc.driver.OracleDriver' + + def create_connect_args(self, url): + hostname = url.host + port = url.port or '1521' + dbname = url.database + jdbc_url = 'jdbc:oracle:thin:@%s:%s:%s' % (hostname, port, dbname) + return [[jdbc_url, url.username, url.password, self.jdbc_driver_name], + self._driver_kwargs()] + + def _get_server_version_info(self, connection): + version = re.search(r'Release ([\d\.]+)', connection.connection.dbversion).group(1) + return tuple(int(x) for x in version.split('.')) + +dialect = Oracle_jdbc diff --git a/lib/sqlalchemy/dialects/postgres.py b/lib/sqlalchemy/dialects/postgres.py new file mode 100644 index 000000000..e66989fa7 --- /dev/null +++ b/lib/sqlalchemy/dialects/postgres.py @@ -0,0 +1,9 @@ +# backwards compat with the old name +from sqlalchemy.util import warn_deprecated + +warn_deprecated( + "The SQLAlchemy PostgreSQL dialect has been renamed from 'postgres' to 'postgresql'. " + "The new URL format is postgresql[+driver]://<user>:<pass>@<host>/<dbname>" + ) + +from sqlalchemy.dialects.postgresql import *
\ No newline at end of file diff --git a/lib/sqlalchemy/dialects/postgresql/__init__.py b/lib/sqlalchemy/dialects/postgresql/__init__.py new file mode 100644 index 000000000..af9430a2b --- /dev/null +++ b/lib/sqlalchemy/dialects/postgresql/__init__.py @@ -0,0 +1,3 @@ +from sqlalchemy.dialects.postgresql import base, psycopg2, pg8000, zxjdbc + +base.dialect = psycopg2.dialect
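
With the package layout above, psycopg2 is the default DBAPI for plain
``postgresql://`` URLs; a specific driver can be requested using the
``postgresql+driver`` scheme shown in the deprecation notice. A minimal sketch
(credentials and host are placeholders)::

    from sqlalchemy import create_engine

    engine = create_engine('postgresql+psycopg2://scott:tiger@localhost/test')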
\ No newline at end of file diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py new file mode 100644 index 000000000..874907abc --- /dev/null +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -0,0 +1,898 @@ +# postgresql.py +# Copyright (C) 2005, 2006, 2007, 2008, 2009 Michael Bayer mike_mp@zzzcomputing.com +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Support for the PostgreSQL database. + +For information on connecting using specific drivers, see the documentation section +regarding that driver. + +Sequences/SERIAL +---------------- + +PostgreSQL supports sequences, and SQLAlchemy uses these as the default means of creating +new primary key values for integer-based primary key columns. When creating tables, +SQLAlchemy will issue the ``SERIAL`` datatype for integer-based primary key columns, +which generates a sequence corresponding to the column and associated with it based on +a naming convention. + +To specify a specific named sequence to be used for primary key generation, use the +:func:`~sqlalchemy.schema.Sequence` construct:: + + Table('sometable', metadata, + Column('id', Integer, Sequence('some_id_seq'), primary_key=True) + ) + +Currently, when SQLAlchemy issues a single insert statement, to fulfill the contract of +having the "last insert identifier" available, the sequence is executed independently +beforehand and the new value is retrieved, to be used in the subsequent insert. Note +that when an :func:`~sqlalchemy.sql.expression.insert()` construct is executed using +"executemany" semantics, the sequence is not pre-executed and normal PG SERIAL behavior +is used. + +PostgreSQL 8.3 supports an ``INSERT...RETURNING`` syntax which SQLAlchemy supports +as well. A future release of SQLA will use this feature by default in lieu of +sequence pre-execution in order to retrieve new primary key values, when available. + +INSERT/UPDATE...RETURNING +------------------------- + +The dialect supports PG 8.3's ``INSERT..RETURNING`` and ``UPDATE..RETURNING`` syntaxes, +but must be explicitly enabled on a per-statement basis:: + + # INSERT..RETURNING + result = table.insert(postgresql_returning=[table.c.col1, table.c.col2]).\\ + values(name='foo') + print result.fetchall() + + # UPDATE..RETURNING + result = table.update(postgresql_returning=[table.c.col1, table.c.col2]).\\ + where(table.c.name=='foo').values(name='bar') + print result.fetchall() + +Indexes +------- + +PostgreSQL supports partial indexes. 
To create them pass a postgresql_where +option to the Index constructor:: + + Index('my_index', my_table.c.id, postgresql_where=tbl.c.value > 10) + + + +""" + +import re + +from sqlalchemy import schema as sa_schema +from sqlalchemy import sql, schema, exc, util +from sqlalchemy.engine import base, default, reflection +from sqlalchemy.sql import compiler, expression +from sqlalchemy.sql import operators as sql_operators +from sqlalchemy import types as sqltypes + +from sqlalchemy.types import INTEGER, BIGINT, SMALLINT, VARCHAR, \ + CHAR, TEXT, FLOAT, NUMERIC, \ + TIMESTAMP, TIME, DATE, BOOLEAN + +class REAL(sqltypes.Float): + __visit_name__ = "REAL" + +class BYTEA(sqltypes.Binary): + __visit_name__ = 'BYTEA' + +class DOUBLE_PRECISION(sqltypes.Float): + __visit_name__ = 'DOUBLE_PRECISION' + +class INET(sqltypes.TypeEngine): + __visit_name__ = "INET" +PGInet = INET + +class CIDR(sqltypes.TypeEngine): + __visit_name__ = "CIDR" +PGCidr = CIDR + +class MACADDR(sqltypes.TypeEngine): + __visit_name__ = "MACADDR" +PGMacAddr = MACADDR + +class INTERVAL(sqltypes.TypeEngine): + __visit_name__ = 'INTERVAL' +PGInterval = INTERVAL + +class BIT(sqltypes.TypeEngine): + __visit_name__ = 'BIT' +PGBit = BIT + +class UUID(sqltypes.TypeEngine): + __visit_name__ = 'UUID' +PGUuid = UUID + +class ARRAY(sqltypes.MutableType, sqltypes.Concatenable, sqltypes.TypeEngine): + __visit_name__ = 'ARRAY' + + def __init__(self, item_type, mutable=True): + if isinstance(item_type, type): + item_type = item_type() + self.item_type = item_type + self.mutable = mutable + + def copy_value(self, value): + if value is None: + return None + elif self.mutable: + return list(value) + else: + return value + + def compare_values(self, x, y): + return x == y + + def is_mutable(self): + return self.mutable + + def dialect_impl(self, dialect, **kwargs): + impl = self.__class__.__new__(self.__class__) + impl.__dict__.update(self.__dict__) + impl.item_type = self.item_type.dialect_impl(dialect) + return impl + + def bind_processor(self, dialect): + item_proc = self.item_type.bind_processor(dialect) + def process(value): + if value is None: + return value + def convert_item(item): + if isinstance(item, (list, tuple)): + return [convert_item(child) for child in item] + else: + if item_proc: + return item_proc(item) + else: + return item + return [convert_item(item) for item in value] + return process + + def result_processor(self, dialect): + item_proc = self.item_type.result_processor(dialect) + def process(value): + if value is None: + return value + def convert_item(item): + if isinstance(item, list): + return [convert_item(child) for child in item] + else: + if item_proc: + return item_proc(item) + else: + return item + return [convert_item(item) for item in value] + return process +PGArray = ARRAY + +colspecs = { + sqltypes.Interval:INTERVAL +} + +ischema_names = { + 'integer' : INTEGER, + 'bigint' : BIGINT, + 'smallint' : SMALLINT, + 'character varying' : VARCHAR, + 'character' : CHAR, + '"char"' : sqltypes.String, + 'name' : sqltypes.String, + 'text' : TEXT, + 'numeric' : NUMERIC, + 'float' : FLOAT, + 'real' : REAL, + 'inet': INET, + 'cidr': CIDR, + 'uuid': UUID, + 'bit':BIT, + 'macaddr': MACADDR, + 'double precision' : DOUBLE_PRECISION, + 'timestamp' : TIMESTAMP, + 'timestamp with time zone' : TIMESTAMP, + 'timestamp without time zone' : TIMESTAMP, + 'time with time zone' : TIME, + 'time without time zone' : TIME, + 'date' : DATE, + 'time': TIME, + 'bytea' : BYTEA, + 'boolean' : BOOLEAN, + 'interval':INTERVAL, +} + + + +class 
PGCompiler(compiler.SQLCompiler): + + def visit_match_op(self, binary, **kw): + return "%s @@ to_tsquery(%s)" % (self.process(binary.left), self.process(binary.right)) + + def visit_ilike_op(self, binary, **kw): + escape = binary.modifiers.get("escape", None) + return '%s ILIKE %s' % (self.process(binary.left), self.process(binary.right)) \ + + (escape and ' ESCAPE \'%s\'' % escape or '') + + def visit_notilike_op(self, binary, **kw): + escape = binary.modifiers.get("escape", None) + return '%s NOT ILIKE %s' % (self.process(binary.left), self.process(binary.right)) \ + + (escape and ' ESCAPE \'%s\'' % escape or '') + + def post_process_text(self, text): + if '%%' in text: + util.warn("The SQLAlchemy postgresql dialect now automatically escapes '%' in text() expressions to '%%'.") + return text.replace('%', '%%') + + def visit_sequence(self, seq): + if seq.optional: + return None + else: + return "nextval('%s')" % self.preparer.format_sequence(seq) + + def limit_clause(self, select): + text = "" + if select._limit is not None: + text += " \n LIMIT " + str(select._limit) + if select._offset is not None: + if select._limit is None: + text += " \n LIMIT ALL" + text += " OFFSET " + str(select._offset) + return text + + def get_select_precolumns(self, select): + if select._distinct: + if isinstance(select._distinct, bool): + return "DISTINCT " + elif isinstance(select._distinct, (list, tuple)): + return "DISTINCT ON (" + ', '.join( + [(isinstance(col, basestring) and col or self.process(col)) for col in select._distinct] + )+ ") " + else: + return "DISTINCT ON (" + unicode(select._distinct) + ") " + else: + return "" + + def for_update_clause(self, select): + if select.for_update == 'nowait': + return " FOR UPDATE NOWAIT" + else: + return super(PGCompiler, self).for_update_clause(select) + + def returning_clause(self, stmt, returning_cols): + + columns = [ + self.process( + self.label_select_column(None, c, asfrom=False), + within_columns_clause=True, + result_map=self.result_map) + for c in expression._select_iterables(returning_cols) + ] + + return 'RETURNING ' + ', '.join(columns) + + def visit_extract(self, extract, **kwargs): + field = self.extract_map.get(extract.field, extract.field) + return "EXTRACT(%s FROM %s::timestamp)" % ( + field, self.process(extract.expr)) + +class PGDDLCompiler(compiler.DDLCompiler): + def get_column_specification(self, column, **kwargs): + colspec = self.preparer.format_column(column) + if column.primary_key and \ + len(column.foreign_keys)==0 and \ + column.autoincrement and \ + isinstance(column.type, sqltypes.Integer) and \ + not isinstance(column.type, sqltypes.SmallInteger) and \ + (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)): + if isinstance(column.type, sqltypes.BigInteger): + colspec += " BIGSERIAL" + else: + colspec += " SERIAL" + else: + colspec += " " + self.dialect.type_compiler.process(column.type) + default = self.get_column_default_string(column) + if default is not None: + colspec += " DEFAULT " + default + + if not column.nullable: + colspec += " NOT NULL" + return colspec + + def visit_create_sequence(self, create): + return "CREATE SEQUENCE %s" % self.preparer.format_sequence(create.element) + + def visit_drop_sequence(self, drop): + return "DROP SEQUENCE %s" % self.preparer.format_sequence(drop.element) + + def visit_create_index(self, create): + preparer = self.preparer + index = create.element + text = "CREATE " + if index.unique: + text += "UNIQUE " + text += "INDEX %s ON %s (%s)" \ 
+ % (preparer.quote(self._validate_identifier(index.name, True), index.quote), + preparer.format_table(index.table), + ', '.join([preparer.format_column(c) for c in index.columns])) + + if "postgres_where" in index.kwargs: + whereclause = index.kwargs['postgres_where'] + util.warn_deprecated("The 'postgres_where' argument has been renamed to 'postgresql_where'.") + elif 'postgresql_where' in index.kwargs: + whereclause = index.kwargs['postgresql_where'] + else: + whereclause = None + + if whereclause is not None: + compiler = self._compile(whereclause, None) + # this might belong to the compiler class + inlined_clause = str(compiler) % dict( + [(key,bind.value) for key,bind in compiler.binds.iteritems()]) + text += " WHERE " + inlined_clause + return text + + +class PGDefaultRunner(base.DefaultRunner): + def __init__(self, context): + base.DefaultRunner.__init__(self, context) + # craete cursor which won't conflict with a server-side cursor + self.cursor = context._connection.connection.cursor() + + def get_column_default(self, column, isinsert=True): + if column.primary_key: + # pre-execute passive defaults on primary keys + if (isinstance(column.server_default, schema.DefaultClause) and + column.server_default.arg is not None): + return self.execute_string("select %s" % column.server_default.arg) + elif (isinstance(column.type, sqltypes.Integer) and column.autoincrement) \ + and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)): + sch = column.table.schema + # TODO: this has to build into the Sequence object so we can get the quoting + # logic from it + if sch is not None: + exc = "select nextval('\"%s\".\"%s_%s_seq\"')" % (sch, column.table.name, column.name) + else: + exc = "select nextval('\"%s_%s_seq\"')" % (column.table.name, column.name) + + if self.dialect.supports_unicode_statements: + return self.execute_string(exc) + else: + return self.execute_string(exc.encode(self.dialect.encoding)) + + return super(PGDefaultRunner, self).get_column_default(column) + + def visit_sequence(self, seq): + if not seq.optional: + return self.execute_string(("select nextval('%s')" % \ + self.dialect.identifier_preparer.format_sequence(seq))) + else: + return None + +class PGTypeCompiler(compiler.GenericTypeCompiler): + def visit_INET(self, type_): + return "INET" + + def visit_CIDR(self, type_): + return "CIDR" + + def visit_MACADDR(self, type_): + return "MACADDR" + + def visit_FLOAT(self, type_): + if not type_.precision: + return "FLOAT" + else: + return "FLOAT(%(precision)s)" % {'precision': type_.precision} + + def visit_DOUBLE_PRECISION(self, type_): + return "DOUBLE PRECISION" + + def visit_BIGINT(self, type_): + return "BIGINT" + + def visit_datetime(self, type_): + return self.visit_TIMESTAMP(type_) + + def visit_TIMESTAMP(self, type_): + return "TIMESTAMP " + (type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE" + + def visit_TIME(self, type_): + return "TIME " + (type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE" + + def visit_INTERVAL(self, type_): + return "INTERVAL" + + def visit_BIT(self, type_): + return "BIT" + + def visit_UUID(self, type_): + return "UUID" + + def visit_binary(self, type_): + return self.visit_BYTEA(type_) + + def visit_BYTEA(self, type_): + return "BYTEA" + + def visit_REAL(self, type_): + return "REAL" + + def visit_ARRAY(self, type_): + return self.process(type_.item_type) + '[]' + +class PGIdentifierPreparer(compiler.IdentifierPreparer): + def _unquote_identifier(self, value): + if value[0] == 
self.initial_quote: + value = value[1:-1].replace('""','"') + return value + +class PGInspector(reflection.Inspector): + + def __init__(self, conn): + reflection.Inspector.__init__(self, conn) + + def get_table_oid(self, table_name, schema=None): + """Return the oid from `table_name` and `schema`.""" + + return self.dialect.get_table_oid(self.conn, table_name, schema, + info_cache=self.info_cache) + + +class PGDialect(default.DefaultDialect): + name = 'postgresql' + supports_alter = True + max_identifier_length = 63 + supports_sane_rowcount = True + + supports_sequences = True + sequences_optional = True + preexecute_autoincrement_sequences = True + postfetch_lastrowid = False + + supports_default_values = True + supports_empty_insert = False + default_paramstyle = 'pyformat' + ischema_names = ischema_names + colspecs = colspecs + + statement_compiler = PGCompiler + ddl_compiler = PGDDLCompiler + type_compiler = PGTypeCompiler + preparer = PGIdentifierPreparer + defaultrunner = PGDefaultRunner + inspector = PGInspector + isolation_level = None + + def __init__(self, isolation_level=None, **kwargs): + default.DefaultDialect.__init__(self, **kwargs) + self.isolation_level = isolation_level + + def initialize(self, connection): + super(PGDialect, self).initialize(connection) + self.implicit_returning = self.server_version_info > (8, 3) and \ + self.__dict__.get('implicit_returning', True) + + def visit_pool(self, pool): + if self.isolation_level is not None: + class SetIsolationLevel(object): + def __init__(self, isolation_level): + self.isolation_level = isolation_level + + def connect(self, conn, rec): + cursor = conn.cursor() + cursor.execute("SET SESSION CHARACTERISTICS AS TRANSACTION ISOLATION LEVEL %s" + % self.isolation_level) + cursor.execute("COMMIT") + cursor.close() + pool.add_listener(SetIsolationLevel(self.isolation_level)) + + def do_begin_twophase(self, connection, xid): + self.do_begin(connection.connection) + + def do_prepare_twophase(self, connection, xid): + connection.execute("PREPARE TRANSACTION '%s'" % xid) + + def do_rollback_twophase(self, connection, xid, is_prepared=True, recover=False): + if is_prepared: + if recover: + #FIXME: ugly hack to get out of transaction context when commiting recoverable transactions + # Must find out a way how to make the dbapi not open a transaction. + connection.execute("ROLLBACK") + connection.execute("ROLLBACK PREPARED '%s'" % xid) + connection.execute("BEGIN") + self.do_rollback(connection.connection) + else: + self.do_rollback(connection.connection) + + def do_commit_twophase(self, connection, xid, is_prepared=True, recover=False): + if is_prepared: + if recover: + connection.execute("ROLLBACK") + connection.execute("COMMIT PREPARED '%s'" % xid) + connection.execute("BEGIN") + self.do_rollback(connection.connection) + else: + self.do_commit(connection.connection) + + def do_recover_twophase(self, connection): + resultset = connection.execute(sql.text("SELECT gid FROM pg_prepared_xacts")) + return [row[0] for row in resultset] + + def get_default_schema_name(self, connection): + return connection.scalar("select current_schema()") + + def has_table(self, connection, table_name, schema=None): + # seems like case gets folded in pg_class... 
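+        # hence the queries below compare lower(relname) to the lower-cased
+        # incoming name, matching PG's folding of unquoted identifiers.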
+ if schema is None: + cursor = connection.execute( + sql.text("select relname from pg_class c join pg_namespace n on " + "n.oid=c.relnamespace where n.nspname=current_schema() and lower(relname)=:name", + bindparams=[sql.bindparam('name', unicode(table_name.lower()), type_=sqltypes.Unicode)] + ) + ) + else: + cursor = connection.execute( + sql.text("select relname from pg_class c join pg_namespace n on " + "n.oid=c.relnamespace where n.nspname=:schema and lower(relname)=:name", + bindparams=[sql.bindparam('name', unicode(table_name.lower()), type_=sqltypes.Unicode), + sql.bindparam('schema', unicode(schema), type_=sqltypes.Unicode)] + ) + ) + return bool(cursor.fetchone()) + + def has_sequence(self, connection, sequence_name): + cursor = connection.execute( + sql.text("SELECT relname FROM pg_class WHERE relkind = 'S' AND " + "relnamespace IN ( SELECT oid FROM pg_namespace WHERE nspname NOT LIKE 'pg_%' " + "AND nspname != 'information_schema' AND relname = :seqname)", + bindparams=[sql.bindparam('seqname', unicode(sequence_name), type_=sqltypes.Unicode)] + )) + return bool(cursor.fetchone()) + + def table_names(self, connection, schema): + result = connection.execute( + sql.text(u"""SELECT relname + FROM pg_class c + WHERE relkind = 'r' + AND '%s' = (select nspname from pg_namespace n where n.oid = c.relnamespace)""" % schema, + typemap = {'relname':sqltypes.Unicode} + ) + ) + return [row[0] for row in result] + + def _get_server_version_info(self, connection): + v = connection.execute("select version()").scalar() + m = re.match('PostgreSQL (\d+)\.(\d+)\.(\d+)', v) + if not m: + raise AssertionError("Could not determine version from string '%s'" % v) + return tuple([int(x) for x in m.group(1, 2, 3)]) + + @reflection.cache + def get_table_oid(self, connection, table_name, schema=None, **kw): + """Fetch the oid for schema.table_name. + + Several reflection methods require the table oid. The idea for using + this method is that it can be fetched one time and cached for + subsequent calls. + + """ + table_oid = None + if schema is not None: + schema_where_clause = "n.nspname = :schema" + else: + schema_where_clause = "pg_catalog.pg_table_is_visible(c.oid)" + query = """ + SELECT c.oid + FROM pg_catalog.pg_class c + LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace + WHERE (%s) + AND c.relname = :table_name AND c.relkind in ('r','v') + """ % schema_where_clause + # Since we're binding to unicode, table_name and schema_name must be + # unicode. + table_name = unicode(table_name) + if schema is not None: + schema = unicode(schema) + s = sql.text(query, bindparams=[ + sql.bindparam('table_name', type_=sqltypes.Unicode), + sql.bindparam('schema', type_=sqltypes.Unicode) + ], + typemap={'oid':sqltypes.Integer} + ) + c = connection.execute(s, table_name=table_name, schema=schema) + table_oid = c.scalar() + if table_oid is None: + raise exc.NoSuchTableError(table_name) + return table_oid + + @reflection.cache + def get_schema_names(self, connection, **kw): + s = """ + SELECT nspname + FROM pg_namespace + ORDER BY nspname + """ + rp = connection.execute(s) + # what about system tables? 
+ schema_names = [row[0].decode(self.encoding) for row in rp \ + if not row[0].startswith('pg_')] + return schema_names + + @reflection.cache + def get_table_names(self, connection, schema=None, **kw): + if schema is not None: + current_schema = schema + else: + current_schema = self.get_default_schema_name(connection) + table_names = self.table_names(connection, current_schema) + return table_names + + @reflection.cache + def get_view_names(self, connection, schema=None, **kw): + if schema is not None: + current_schema = schema + else: + current_schema = self.get_default_schema_name(connection) + s = """ + SELECT relname + FROM pg_class c + WHERE relkind = 'v' + AND '%(schema)s' = (select nspname from pg_namespace n where n.oid = c.relnamespace) + """ % dict(schema=current_schema) + view_names = [row[0].decode(self.encoding) for row in connection.execute(s)] + return view_names + + @reflection.cache + def get_view_definition(self, connection, view_name, schema=None, **kw): + if schema is not None: + current_schema = schema + else: + current_schema = self.get_default_schema_name(connection) + s = """ + SELECT definition FROM pg_views + WHERE schemaname = :schema + AND viewname = :view_name + """ + rp = connection.execute(sql.text(s), + view_name=view_name, schema=current_schema) + if rp: + view_def = rp.scalar().decode(self.encoding) + return view_def + + @reflection.cache + def get_columns(self, connection, table_name, schema=None, **kw): + + table_oid = self.get_table_oid(connection, table_name, schema, + info_cache=kw.get('info_cache')) + SQL_COLS = """ + SELECT a.attname, + pg_catalog.format_type(a.atttypid, a.atttypmod), + (SELECT substring(d.adsrc for 128) FROM pg_catalog.pg_attrdef d + WHERE d.adrelid = a.attrelid AND d.adnum = a.attnum AND a.atthasdef) + AS DEFAULT, + a.attnotnull, a.attnum, a.attrelid as table_oid + FROM pg_catalog.pg_attribute a + WHERE a.attrelid = :table_oid + AND a.attnum > 0 AND NOT a.attisdropped + ORDER BY a.attnum + """ + s = sql.text(SQL_COLS, + bindparams=[sql.bindparam('table_oid', type_=sqltypes.Integer)], + typemap={'attname':sqltypes.Unicode, 'default':sqltypes.Unicode} + ) + c = connection.execute(s, table_oid=table_oid) + rows = c.fetchall() + domains = self._load_domains(connection) + # format columns + columns = [] + for name, format_type, default, notnull, attnum, table_oid in rows: + ## strip (30) from character varying(30) + attype = re.search('([^\([]+)', format_type).group(1) + nullable = not notnull + is_array = format_type.endswith('[]') + try: + charlen = re.search('\(([\d,]+)\)', format_type).group(1) + except: + charlen = False + numericprec = False + numericscale = False + if attype == 'numeric': + if charlen is False: + numericprec, numericscale = (None, None) + else: + numericprec, numericscale = charlen.split(',') + charlen = False + if attype == 'double precision': + numericprec, numericscale = (53, False) + charlen = False + if attype == 'integer': + numericprec, numericscale = (32, 0) + charlen = False + args = [] + for a in (charlen, numericprec, numericscale): + if a is None: + args.append(None) + elif a is not False: + args.append(int(a)) + kwargs = {} + if attype == 'timestamp with time zone': + kwargs['timezone'] = True + elif attype == 'timestamp without time zone': + kwargs['timezone'] = False + if attype in self.ischema_names: + coltype = self.ischema_names[attype] + else: + if attype in domains: + domain = domains[attype] + if domain['attype'] in self.ischema_names: + # A table can't override whether the domain is 
nullable. + nullable = domain['nullable'] + if domain['default'] and not default: + # It can, however, override the default value, but can't set it to null. + default = domain['default'] + coltype = self.ischema_names[domain['attype']] + else: + coltype = None + if coltype: + coltype = coltype(*args, **kwargs) + if is_array: + coltype = PGArray(coltype) + else: + util.warn("Did not recognize type '%s' of column '%s'" % + (attype, name)) + coltype = sqltypes.NULLTYPE + # adjust the default value + autoincrement = False + if default is not None: + match = re.search(r"""(nextval\(')([^']+)('.*$)""", default) + if match is not None: + autoincrement = True + # the default is related to a Sequence + sch = schema + if '.' not in match.group(2) and sch is not None: + # unconditionally quote the schema name. this could + # later be enhanced to obey quoting rules / "quote schema" + default = match.group(1) + ('"%s"' % sch) + '.' + match.group(2) + match.group(3) + + column_info = dict(name=name, type=coltype, nullable=nullable, + default=default, autoincrement=autoincrement) + columns.append(column_info) + return columns + + @reflection.cache + def get_primary_keys(self, connection, table_name, schema=None, **kw): + table_oid = self.get_table_oid(connection, table_name, schema, + info_cache=kw.get('info_cache')) + PK_SQL = """ + SELECT attname FROM pg_attribute + WHERE attrelid = ( + SELECT indexrelid FROM pg_index i + WHERE i.indrelid = :table_oid + AND i.indisprimary = 't') + ORDER BY attnum + """ + t = sql.text(PK_SQL, typemap={'attname':sqltypes.Unicode}) + c = connection.execute(t, table_oid=table_oid) + primary_keys = [r[0] for r in c.fetchall()] + return primary_keys + + @reflection.cache + def get_foreign_keys(self, connection, table_name, schema=None, **kw): + preparer = self.identifier_preparer + table_oid = self.get_table_oid(connection, table_name, schema, + info_cache=kw.get('info_cache')) + FK_SQL = """ + SELECT conname, pg_catalog.pg_get_constraintdef(oid, true) as condef + FROM pg_catalog.pg_constraint r + WHERE r.conrelid = :table AND r.contype = 'f' + ORDER BY 1 + """ + + t = sql.text(FK_SQL, typemap={'conname':sqltypes.Unicode, 'condef':sqltypes.Unicode}) + c = connection.execute(t, table=table_oid) + fkeys = [] + for conname, condef in c.fetchall(): + m = re.search('FOREIGN KEY \((.*?)\) REFERENCES (?:(.*?)\.)?(.*?)\((.*?)\)', condef).groups() + (constrained_columns, referred_schema, referred_table, referred_columns) = m + constrained_columns = [preparer._unquote_identifier(x) for x in re.split(r'\s*,\s*', constrained_columns)] + if referred_schema: + referred_schema = preparer._unquote_identifier(referred_schema) + elif schema is not None and schema == self.get_default_schema_name(connection): + # no schema (i.e. its the default schema), and the table we're + # reflecting has the default schema explicit, then use that. + # i.e. 
try to use the user's conventions + referred_schema = schema + referred_table = preparer._unquote_identifier(referred_table) + referred_columns = [preparer._unquote_identifier(x) for x in re.split(r'\s*,\s*', referred_columns)] + fkey_d = { + 'name' : conname, + 'constrained_columns' : constrained_columns, + 'referred_schema' : referred_schema, + 'referred_table' : referred_table, + 'referred_columns' : referred_columns + } + fkeys.append(fkey_d) + return fkeys + + @reflection.cache + def get_indexes(self, connection, table_name, schema, **kw): + table_oid = self.get_table_oid(connection, table_name, schema, + info_cache=kw.get('info_cache')) + IDX_SQL = """ + SELECT c.relname, i.indisunique, i.indexprs, i.indpred, + a.attname + FROM pg_index i, pg_class c, pg_attribute a + WHERE i.indrelid = :table_oid AND i.indexrelid = c.oid + AND a.attrelid = i.indexrelid AND i.indisprimary = 'f' + ORDER BY c.relname, a.attnum + """ + t = sql.text(IDX_SQL, typemap={'attname':sqltypes.Unicode}) + c = connection.execute(t, table_oid=table_oid) + index_names = {} + indexes = [] + sv_idx_name = None + for row in c.fetchall(): + idx_name, unique, expr, prd, col = row + if expr: + if idx_name != sv_idx_name: + util.warn( + "Skipped unsupported reflection of expression-based index %s" + % idx_name) + sv_idx_name = idx_name + continue + if prd and not idx_name == sv_idx_name: + util.warn( + "Predicate of partial index %s ignored during reflection" + % idx_name) + sv_idx_name = idx_name + if idx_name in index_names: + index_d = index_names[idx_name] + else: + index_d = {'column_names':[]} + indexes.append(index_d) + index_names[idx_name] = index_d + index_d['name'] = idx_name + index_d['column_names'].append(col) + index_d['unique'] = unique + return indexes + + def _load_domains(self, connection): + ## Load data types for domains: + SQL_DOMAINS = """ + SELECT t.typname as "name", + pg_catalog.format_type(t.typbasetype, t.typtypmod) as "attype", + not t.typnotnull as "nullable", + t.typdefault as "default", + pg_catalog.pg_type_is_visible(t.oid) as "visible", + n.nspname as "schema" + FROM pg_catalog.pg_type t + LEFT JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace + LEFT JOIN pg_catalog.pg_constraint r ON t.oid = r.contypid + WHERE t.typtype = 'd' + """ + + s = sql.text(SQL_DOMAINS, typemap={'attname':sqltypes.Unicode}) + c = connection.execute(s) + + domains = {} + for domain in c.fetchall(): + ## strip (30) from character varying(30) + attype = re.search('([^\(]+)', domain['attype']).group(1) + if domain['visible']: + # 'visible' just means whether or not the domain is in a + # schema that's on the search path -- or not overridden by + # a schema with higher precedence. If it's not visible, + # it will be prefixed with the schema-name when it's used. + name = domain['name'] + else: + name = "%s.%s" % (domain['schema'], domain['name']) + + domains[name] = {'attype':attype, 'nullable': domain['nullable'], 'default': domain['default']} + + return domains + diff --git a/lib/sqlalchemy/dialects/postgresql/pg8000.py b/lib/sqlalchemy/dialects/postgresql/pg8000.py new file mode 100644 index 000000000..e8dd03113 --- /dev/null +++ b/lib/sqlalchemy/dialects/postgresql/pg8000.py @@ -0,0 +1,84 @@ +"""Support for the PostgreSQL database via the pg8000 driver. + +Connecting +---------- + +URLs are of the form `postgresql+pg8000://user:password@host:port/dbname[?key=value&key=value...]`.
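+ +A minimal connect sketch; the credentials, host, port and database name below are placeholders, not defaults:: + + from sqlalchemy import create_engine + e = create_engine('postgresql+pg8000://scott:tiger@localhost:5432/test') +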
+ +Unicode +------- + +pg8000 requires that the PostgreSQL client encoding be configured in the postgresql.conf file +in order to use encodings other than ascii. Set this value to the same value as +the "encoding" parameter on create_engine(), usually "utf-8". + +Interval +-------- + +Passing data to and from the Interval type is not yet supported. + +""" +from sqlalchemy.engine import default +import decimal +from sqlalchemy import util +from sqlalchemy import types as sqltypes +from sqlalchemy.dialects.postgresql.base import PGDialect, PGCompiler + +class _PGNumeric(sqltypes.Numeric): + def bind_processor(self, dialect): + return None + + def result_processor(self, dialect): + if self.asdecimal: + return None + else: + def process(value): + if isinstance(value, decimal.Decimal): + return float(value) + else: + return value + return process + +class PostgreSQL_pg8000ExecutionContext(default.DefaultExecutionContext): + pass + +class PostgreSQL_pg8000Compiler(PGCompiler): + def visit_mod(self, binary, **kw): + return self.process(binary.left) + " %% " + self.process(binary.right) + + +class PostgreSQL_pg8000(PGDialect): + driver = 'pg8000' + + supports_unicode_statements = True + + supports_unicode_binds = True + + default_paramstyle = 'format' + supports_sane_multi_rowcount = False + execution_ctx_cls = PostgreSQL_pg8000ExecutionContext + statement_compiler = PostgreSQL_pg8000Compiler + + colspecs = util.update_copy( + PGDialect.colspecs, + { + sqltypes.Numeric : _PGNumeric, + sqltypes.Float: sqltypes.Float, # prevents _PGNumeric from being used + } + ) + + @classmethod + def dbapi(cls): + return __import__('pg8000').dbapi + + def create_connect_args(self, url): + opts = url.translate_connect_args(username='user') + if 'port' in opts: + opts['port'] = int(opts['port']) + opts.update(url.query) + return ([], opts) + + def is_disconnect(self, e): + return "connection is closed" in str(e) + +dialect = PostgreSQL_pg8000 diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2.py b/lib/sqlalchemy/dialects/postgresql/psycopg2.py new file mode 100644 index 000000000..a428878ae --- /dev/null +++ b/lib/sqlalchemy/dialects/postgresql/psycopg2.py @@ -0,0 +1,147 @@ +"""Support for the PostgreSQL database via the psycopg2 driver. + +Driver +------ + +The psycopg2 driver is supported, available at http://pypi.python.org/pypi/psycopg2/ . +The dialect has several behaviors which are specifically tailored towards compatibility +with this module. + +Note that psycopg1 is **not** supported. + +Connecting +---------- + +URLs are of the form `postgresql+psycopg2://user:password@host:port/dbname[?key=value&key=value...]`. + +psycopg2-specific keyword arguments which are accepted by :func:`~sqlalchemy.create_engine()` are: + +* *server_side_cursors* - Enable the usage of "server side cursors" for SQL statements which support + this feature. What this essentially means from a psycopg2 point of view is that the cursor is + created using a name, e.g. `connection.cursor('some name')`, which has the effect that result rows + are not immediately pre-fetched and buffered after statement execution, but are instead left + on the server and only retrieved as needed. SQLAlchemy's :class:`~sqlalchemy.engine.base.ResultProxy` + uses special row-buffering behavior when this feature is enabled, such that groups of 100 rows + at a time are fetched over the wire to reduce conversational overhead. + +* *isolation_level* - Sets the transaction isolation level for each transaction + within the engine.
Valid isolation levels are `READ_COMMITTED`, + `READ_UNCOMMITTED`, `REPEATABLE_READ`, and `SERIALIZABLE`. + +Transactions +------------ + +The psycopg2 dialect fully supports SAVEPOINT and two-phase commit operations. + + +""" + +import decimal, random, re +from sqlalchemy import util +from sqlalchemy.engine import base, default +from sqlalchemy.sql import expression +from sqlalchemy.sql import operators as sql_operators +from sqlalchemy import types as sqltypes +from sqlalchemy.dialects.postgresql.base import PGDialect, PGCompiler + +class _PGNumeric(sqltypes.Numeric): + def bind_processor(self, dialect): + return None + + def result_processor(self, dialect): + if self.asdecimal: + return None + else: + def process(value): + if isinstance(value, decimal.Decimal): + return float(value) + else: + return value + return process + + +# TODO: filter out 'FOR UPDATE' statements +SERVER_SIDE_CURSOR_RE = re.compile( + r'\s*SELECT', + re.I | re.UNICODE) + +class PostgreSQL_psycopg2ExecutionContext(default.DefaultExecutionContext): + def create_cursor(self): + # TODO: coverage for server side cursors + select.for_update() + is_server_side = \ + self.dialect.server_side_cursors and \ + ((self.compiled and isinstance(self.compiled.statement, expression.Selectable) + and not getattr(self.compiled.statement, 'for_update', False)) \ + or \ + ( + (not self.compiled or isinstance(self.compiled.statement, expression._TextClause)) + and self.statement and SERVER_SIDE_CURSOR_RE.match(self.statement)) + ) + + self.__is_server_side = is_server_side + if is_server_side: + # use server-side cursors: + # http://lists.initd.org/pipermail/psycopg/2007-January/005251.html + ident = "c_%s_%s" % (hex(id(self))[2:], hex(random.randint(0, 65535))[2:]) + return self._connection.connection.cursor(ident) + else: + return self._connection.connection.cursor() + + def get_result_proxy(self): + if self.__is_server_side: + return base.BufferedRowResultProxy(self) + else: + return base.ResultProxy(self) + +class PostgreSQL_psycopg2Compiler(PGCompiler): + def visit_mod(self, binary, **kw): + return self.process(binary.left) + " %% " + self.process(binary.right) + + def post_process_text(self, text): + return text.replace('%', '%%') + +class PostgreSQL_psycopg2(PGDialect): + driver = 'psycopg2' + supports_unicode_statements = False + default_paramstyle = 'pyformat' + supports_sane_multi_rowcount = False + execution_ctx_cls = PostgreSQL_psycopg2ExecutionContext + statement_compiler = PostgreSQL_psycopg2Compiler + + colspecs = util.update_copy( + PGDialect.colspecs, + { + sqltypes.Numeric : _PGNumeric, + sqltypes.Float: sqltypes.Float, # prevents _PGNumeric from being used + } + ) + + def __init__(self, server_side_cursors=False, **kwargs): + PGDialect.__init__(self, **kwargs) + self.server_side_cursors = server_side_cursors + + @classmethod + def dbapi(cls): + psycopg = __import__('psycopg2') + return psycopg + + def create_connect_args(self, url): + opts = url.translate_connect_args(username='user') + if 'port' in opts: + opts['port'] = int(opts['port']) + opts.update(url.query) + return ([], opts) + + def is_disconnect(self, e): + if isinstance(e, self.dbapi.OperationalError): + return 'closed the connection' in str(e) or 'connection not open' in str(e) + elif isinstance(e, self.dbapi.InterfaceError): + return 'connection already closed' in str(e) or 'cursor already closed' in str(e) + elif isinstance(e, self.dbapi.ProgrammingError): + # yes, it really says "losed", not "closed" + return "losed the connection unexpectedly" in 
str(e) + else: + return False + +dialect = PostgreSQL_psycopg2 + diff --git a/lib/sqlalchemy/dialects/postgresql/pypostgresql.py b/lib/sqlalchemy/dialects/postgresql/pypostgresql.py new file mode 100644 index 000000000..975006d92 --- /dev/null +++ b/lib/sqlalchemy/dialects/postgresql/pypostgresql.py @@ -0,0 +1,80 @@ +"""Support for the PostgreSQL database via py-postgresql. + +Connecting +---------- + +URLs are of the form `postgresql+pypostgresql://user:password@host:port/dbname[?key=value&key=value...]`. + + +""" +from sqlalchemy.engine import default +import decimal +from sqlalchemy import util +from sqlalchemy import types as sqltypes +from sqlalchemy.dialects.postgresql.base import PGDialect, PGDefaultRunner + +class PGNumeric(sqltypes.Numeric): + def bind_processor(self, dialect): + return None + + def result_processor(self, dialect): + if self.asdecimal: + return None + else: + def process(value): + if isinstance(value, decimal.Decimal): + return float(value) + else: + return value + return process + +class PostgreSQL_pypostgresqlExecutionContext(default.DefaultExecutionContext): + pass + +class PostgreSQL_pypostgresqlDefaultRunner(PGDefaultRunner): + def execute_string(self, stmt, params=None): + return PGDefaultRunner.execute_string(self, stmt, params or ()) + +class PostgreSQL_pypostgresql(PGDialect): + driver = 'pypostgresql' + + supports_unicode_statements = True + + supports_unicode_binds = True + description_encoding = None + + defaultrunner = PostgreSQL_pypostgresqlDefaultRunner + + default_paramstyle = 'format' + + supports_sane_rowcount = False # alas... posting a bug now + + supports_sane_multi_rowcount = False + + execution_ctx_cls = PostgreSQL_pypostgresqlExecutionContext + colspecs = util.update_copy( + PGDialect.colspecs, + { + sqltypes.Numeric : PGNumeric, + sqltypes.Float: sqltypes.Float, # prevents PGNumeric from being used + } + ) + + @classmethod + def dbapi(cls): + from postgresql.driver import dbapi20 + return dbapi20 + + def create_connect_args(self, url): + opts = url.translate_connect_args(username='user') + if 'port' in opts: + opts['port'] = int(opts['port']) + else: + opts['port'] = 5432 + opts.update(url.query) + return ([], opts) + + def is_disconnect(self, e): + return "connection is closed" in str(e) + +dialect = PostgreSQL_pypostgresql diff --git a/lib/sqlalchemy/dialects/postgresql/zxjdbc.py b/lib/sqlalchemy/dialects/postgresql/zxjdbc.py new file mode 100644 index 000000000..b707d2d9e --- /dev/null +++ b/lib/sqlalchemy/dialects/postgresql/zxjdbc.py @@ -0,0 +1,28 @@ +"""Support for the PostgreSQL database via the zxjdbc JDBC connector. + +JDBC Driver +----------- + +The official PostgreSQL JDBC driver is at http://jdbc.postgresql.org/.
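+ +Connecting +---------- + +A minimal connect sketch, running under Jython; the credentials, host and database name are placeholders:: + + from sqlalchemy import create_engine + e = create_engine('postgresql+zxjdbc://scott:tiger@localhost/test') +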
+ +""" +from sqlalchemy.connectors.zxJDBC import ZxJDBCConnector +from sqlalchemy.dialects.postgresql.base import PGCompiler, PGDialect + +class PostgreSQL_jdbcCompiler(PGCompiler): + + def post_process_text(self, text): + # Don't escape '%' like PGCompiler + return text + + +class PostgreSQL_jdbc(ZxJDBCConnector, PGDialect): + statement_compiler = PostgreSQL_jdbcCompiler + + jdbc_db_name = 'postgresql' + jdbc_driver_name = 'org.postgresql.Driver' + + def _get_server_version_info(self, connection): + return tuple(int(x) for x in connection.connection.dbversion.split('.')) + +dialect = PostgreSQL_jdbc diff --git a/lib/sqlalchemy/dialects/sqlite/__init__.py b/lib/sqlalchemy/dialects/sqlite/__init__.py new file mode 100644 index 000000000..3cc08870f --- /dev/null +++ b/lib/sqlalchemy/dialects/sqlite/__init__.py @@ -0,0 +1,4 @@ +from sqlalchemy.dialects.sqlite import base, pysqlite + +# default dialect +base.dialect = pysqlite.dialect
\ No newline at end of file diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py new file mode 100644 index 000000000..8dea91d0a --- /dev/null +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -0,0 +1,526 @@ +# sqlite.py +# Copyright (C) 2005, 2006, 2007, 2008, 2009 Michael Bayer mike_mp@zzzcomputing.com +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php +"""Support for the SQLite database. + +For information on connecting using a specific driver, see the documentation +section regarding that driver. + +Date and Time Types +------------------- + +SQLite does not have built-in DATE, TIME, or DATETIME types, and pysqlite does not provide +out-of-the-box functionality for translating values between Python `datetime` objects +and a SQLite-supported format. SQLAlchemy's own :class:`~sqlalchemy.types.DateTime` +and related types provide date formatting and parsing functionality when SQLite is used. +The implementation classes are :class:`_SLDateTime`, :class:`_SLDate` and :class:`_SLTime`. +These types represent dates and times as ISO formatted strings, which also nicely +support ordering. There's no reliance on typical "libc" internals for these functions +so historical dates are fully supported. + + +""" + +import datetime, re, time + +from sqlalchemy import schema as sa_schema +from sqlalchemy import sql, exc, pool, DefaultClause +from sqlalchemy.engine import default +from sqlalchemy.engine import reflection +from sqlalchemy import types as sqltypes +from sqlalchemy import util +from sqlalchemy.sql import compiler, functions as sql_functions +from sqlalchemy.util import NoneType + +from sqlalchemy.types import BLOB, BOOLEAN, CHAR, DATE, DATETIME, DECIMAL,\ + FLOAT, INTEGER, NUMERIC, SMALLINT, TEXT, TIME,\ + TIMESTAMP, VARCHAR + + +class _NumericMixin(object): + def bind_processor(self, dialect): + type_ = self.asdecimal and str or float + def process(value): + if value is not None: + return type_(value) + else: + return value + return process + +class _SLNumeric(_NumericMixin, sqltypes.Numeric): + pass + +class _SLFloat(_NumericMixin, sqltypes.Float): + pass + +# since SQLite has no date types, we're assuming that SQLite via ODBC +# or JDBC would similarly have no built-in date support, so the "string" based logic +# would apply to all implementing dialects.
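+ +# A usage sketch of this string-based round-trip (the engine URL and table +# are illustrative only, not part of this module): +# +# from datetime import datetime +# from sqlalchemy import create_engine, MetaData, Table, Column, Integer, DateTime +# +# e = create_engine('sqlite://') +# m = MetaData() +# t = Table('t', m, Column('id', Integer, primary_key=True), Column('ts', DateTime)) +# m.create_all(e) +# e.execute(t.insert(), ts=datetime(2009, 8, 6, 21, 11, 27)) +# # bound as the ISO string '2009-08-06 21:11:27.000000', and parsed back +# # into a datetime.datetime by the regexp-based result processors below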
+class _DateTimeMixin(object): + def _bind_processor(self, format, elements): + def process(value): + if not isinstance(value, (NoneType, datetime.date, datetime.datetime, datetime.time)): + raise TypeError("SQLite Date, Time, and DateTime types only accept Python datetime objects as input.") + elif value is not None: + return format % tuple([getattr(value, attr, 0) for attr in elements]) + else: + return None + return process + + def _result_processor(self, fn, regexp): + def process(value): + if value is not None: + return fn(*[int(x or 0) for x in regexp.match(value).groups()]) + else: + return None + return process + +class _SLDateTime(_DateTimeMixin, sqltypes.DateTime): + __legacy_microseconds__ = False + + def bind_processor(self, dialect): + if self.__legacy_microseconds__: + return self._bind_processor( + "%4.4d-%2.2d-%2.2d %2.2d:%2.2d:%2.2d.%s", + ("year", "month", "day", "hour", "minute", "second", "microsecond") + ) + else: + return self._bind_processor( + "%4.4d-%2.2d-%2.2d %2.2d:%2.2d:%2.2d.%06d", + ("year", "month", "day", "hour", "minute", "second", "microsecond") + ) + + _reg = re.compile(r"(\d+)-(\d+)-(\d+)(?: (\d+):(\d+):(\d+)(?:\.(\d+))?)?") + def result_processor(self, dialect): + return self._result_processor(datetime.datetime, self._reg) + +class _SLDate(_DateTimeMixin, sqltypes.Date): + def bind_processor(self, dialect): + return self._bind_processor( + "%4.4d-%2.2d-%2.2d", + ("year", "month", "day") + ) + + _reg = re.compile(r"(\d+)-(\d+)-(\d+)") + def result_processor(self, dialect): + return self._result_processor(datetime.date, self._reg) + +class _SLTime(_DateTimeMixin, sqltypes.Time): + __legacy_microseconds__ = False + + def bind_processor(self, dialect): + if self.__legacy_microseconds__: + return self._bind_processor( + "%2.2d:%2.2d:%2.2d.%s", + ("hour", "minute", "second", "microsecond") + ) + else: + return self._bind_processor( + "%2.2d:%2.2d:%2.2d.%06d", + ("hour", "minute", "second", "microsecond") + ) + + _reg = re.compile(r"(\d+):(\d+):(\d+)(?:\.(\d+))?") + def result_processor(self, dialect): + return self._result_processor(datetime.time, self._reg) + + +class _SLBoolean(sqltypes.Boolean): + def bind_processor(self, dialect): + def process(value): + if value is None: + return None + return value and 1 or 0 + return process + + def result_processor(self, dialect): + def process(value): + if value is None: + return None + return value == 1 + return process + +colspecs = { + sqltypes.Boolean: _SLBoolean, + sqltypes.Date: _SLDate, + sqltypes.DateTime: _SLDateTime, + sqltypes.Float: _SLFloat, + sqltypes.Numeric: _SLNumeric, + sqltypes.Time: _SLTime, +} + +ischema_names = { + 'BLOB': sqltypes.BLOB, + 'BOOL': sqltypes.BOOLEAN, + 'BOOLEAN': sqltypes.BOOLEAN, + 'CHAR': sqltypes.CHAR, + 'DATE': sqltypes.DATE, + 'DATETIME': sqltypes.DATETIME, + 'DECIMAL': sqltypes.DECIMAL, + 'FLOAT': sqltypes.FLOAT, + 'INT': sqltypes.INTEGER, + 'INTEGER': sqltypes.INTEGER, + 'NUMERIC': sqltypes.NUMERIC, + 'REAL': sqltypes.Numeric, + 'SMALLINT': sqltypes.SMALLINT, + 'TEXT': sqltypes.TEXT, + 'TIME': sqltypes.TIME, + 'TIMESTAMP': sqltypes.TIMESTAMP, + 'VARCHAR': sqltypes.VARCHAR, +} + + + +class SQLiteCompiler(compiler.SQLCompiler): + extract_map = compiler.SQLCompiler.extract_map.copy() + extract_map.update({ + 'month': '%m', + 'day': '%d', + 'year': '%Y', + 'second': '%S', + 'hour': '%H', + 'doy': '%j', + 'minute': '%M', + 'epoch': '%s', + 'dow': '%w', + 'week': '%W' + }) + + def visit_now_func(self, fn, **kw): + return "CURRENT_TIMESTAMP" + + def visit_char_length_func(self, 
fn, **kw): + return "length%s" % self.function_argspec(fn) + + def visit_cast(self, cast, **kwargs): + if self.dialect.supports_cast: + return super(SQLiteCompiler, self).visit_cast(cast) + else: + return self.process(cast.clause) + + def visit_extract(self, extract): + try: + return "CAST(STRFTIME('%s', %s) AS INTEGER)" % ( + self.extract_map[extract.field], self.process(extract.expr)) + except KeyError: + raise exc.ArgumentError( + "%s is not a valid extract argument." % extract.field) + + def limit_clause(self, select): + text = "" + if select._limit is not None: + text += " \n LIMIT " + str(select._limit) + if select._offset is not None: + if select._limit is None: + text += " \n LIMIT -1" + text += " OFFSET " + str(select._offset) + else: + text += " OFFSET 0" + return text + + def for_update_clause(self, select): + # sqlite has no "FOR UPDATE" AFAICT + return '' + + +class SQLiteDDLCompiler(compiler.DDLCompiler): + + def get_column_specification(self, column, **kwargs): + colspec = self.preparer.format_column(column) + " " + self.dialect.type_compiler.process(column.type) + default = self.get_column_default_string(column) + if default is not None: + colspec += " DEFAULT " + default + + if not column.nullable: + colspec += " NOT NULL" + return colspec + +class SQLiteTypeCompiler(compiler.GenericTypeCompiler): + def visit_binary(self, type_): + return self.visit_BLOB(type_) + +class SQLiteIdentifierPreparer(compiler.IdentifierPreparer): + reserved_words = set([ + 'add', 'after', 'all', 'alter', 'analyze', 'and', 'as', 'asc', + 'attach', 'autoincrement', 'before', 'begin', 'between', 'by', + 'cascade', 'case', 'cast', 'check', 'collate', 'column', 'commit', + 'conflict', 'constraint', 'create', 'cross', 'current_date', + 'current_time', 'current_timestamp', 'database', 'default', + 'deferrable', 'deferred', 'delete', 'desc', 'detach', 'distinct', + 'drop', 'each', 'else', 'end', 'escape', 'except', 'exclusive', + 'explain', 'false', 'fail', 'for', 'foreign', 'from', 'full', 'glob', + 'group', 'having', 'if', 'ignore', 'immediate', 'in', 'index', + 'indexed', 'initially', 'inner', 'insert', 'instead', 'intersect', 'into', 'is', + 'isnull', 'join', 'key', 'left', 'like', 'limit', 'match', 'natural', + 'not', 'notnull', 'null', 'of', 'offset', 'on', 'or', 'order', 'outer', + 'plan', 'pragma', 'primary', 'query', 'raise', 'references', + 'reindex', 'rename', 'replace', 'restrict', 'right', 'rollback', + 'row', 'select', 'set', 'table', 'temp', 'temporary', 'then', 'to', + 'transaction', 'trigger', 'true', 'union', 'unique', 'update', 'using', + 'vacuum', 'values', 'view', 'virtual', 'when', 'where', + ]) + +class SQLiteDialect(default.DefaultDialect): + name = 'sqlite' + supports_alter = False + supports_unicode_statements = True + supports_unicode_binds = True + supports_default_values = True + supports_empty_insert = False + supports_cast = True + + default_paramstyle = 'qmark' + statement_compiler = SQLiteCompiler + ddl_compiler = SQLiteDDLCompiler + type_compiler = SQLiteTypeCompiler + preparer = SQLiteIdentifierPreparer + ischema_names = ischema_names + colspecs = colspecs + isolation_level = None + + def __init__(self, isolation_level=None, **kwargs): + default.DefaultDialect.__init__(self, **kwargs) + if isolation_level and isolation_level not in ('SERIALIZABLE', + 'READ UNCOMMITTED'): + raise exc.ArgumentError("Invalid value for isolation_level. 
" + "Valid isolation levels for sqlite are 'SERIALIZABLE' and " + "'READ UNCOMMITTED'.") + self.isolation_level = isolation_level + + def visit_pool(self, pool): + if self.isolation_level is not None: + class SetIsolationLevel(object): + def __init__(self, isolation_level): + if isolation_level == 'READ UNCOMMITTED': + self.isolation_level = 1 + else: + self.isolation_level = 0 + + def connect(self, conn, rec): + cursor = conn.cursor() + cursor.execute("PRAGMA read_uncommitted = %d" % self.isolation_level) + cursor.close() + pool.add_listener(SetIsolationLevel(self.isolation_level)) + + def table_names(self, connection, schema): + if schema is not None: + qschema = self.identifier_preparer.quote_identifier(schema) + master = '%s.sqlite_master' % qschema + s = ("SELECT name FROM %s " + "WHERE type='table' ORDER BY name") % (master,) + rs = connection.execute(s) + else: + try: + s = ("SELECT name FROM " + " (SELECT * FROM sqlite_master UNION ALL " + " SELECT * FROM sqlite_temp_master) " + "WHERE type='table' ORDER BY name") + rs = connection.execute(s) + except exc.DBAPIError: + raise + s = ("SELECT name FROM sqlite_master " + "WHERE type='table' ORDER BY name") + rs = connection.execute(s) + + return [row[0] for row in rs] + + def has_table(self, connection, table_name, schema=None): + quote = self.identifier_preparer.quote_identifier + if schema is not None: + pragma = "PRAGMA %s." % quote(schema) + else: + pragma = "PRAGMA " + qtable = quote(table_name) + cursor = _pragma_cursor(connection.execute("%stable_info(%s)" % (pragma, qtable))) + row = cursor.fetchone() + + # consume remaining rows, to work around + # http://www.sqlite.org/cvstrac/tktview?tn=1884 + while cursor.fetchone() is not None: + pass + + return (row is not None) + + @reflection.cache + def get_table_names(self, connection, schema=None, **kw): + return self.table_names(connection, schema) + + @reflection.cache + def get_view_names(self, connection, schema=None, **kw): + if schema is not None: + qschema = self.identifier_preparer.quote_identifier(schema) + master = '%s.sqlite_master' % qschema + s = ("SELECT name FROM %s " + "WHERE type='view' ORDER BY name") % (master,) + rs = connection.execute(s) + else: + try: + s = ("SELECT name FROM " + " (SELECT * FROM sqlite_master UNION ALL " + " SELECT * FROM sqlite_temp_master) " + "WHERE type='view' ORDER BY name") + rs = connection.execute(s) + except exc.DBAPIError: + raise + s = ("SELECT name FROM sqlite_master " + "WHERE type='view' ORDER BY name") + rs = connection.execute(s) + + return [row[0] for row in rs] + + @reflection.cache + def get_view_definition(self, connection, view_name, schema=None, **kw): + quote = self.identifier_preparer.quote_identifier + if schema is not None: + qschema = self.identifier_preparer.quote_identifier(schema) + master = '%s.sqlite_master' % qschema + s = ("SELECT sql FROM %s WHERE name = '%s'" + "AND type='view'") % (master, view_name) + rs = connection.execute(s) + else: + try: + s = ("SELECT sql FROM " + " (SELECT * FROM sqlite_master UNION ALL " + " SELECT * FROM sqlite_temp_master) " + "WHERE name = '%s' " + "AND type='view'") % view_name + rs = connection.execute(s) + except exc.DBAPIError: + raise + s = ("SELECT sql FROM sqlite_master WHERE name = '%s' " + "AND type='view'") % view_name + rs = connection.execute(s) + + result = rs.fetchall() + if result: + return result[0].sql + + @reflection.cache + def get_columns(self, connection, table_name, schema=None, **kw): + quote = self.identifier_preparer.quote_identifier + if schema is not 
None: + pragma = "PRAGMA %s." % quote(schema) + else: + pragma = "PRAGMA " + qtable = quote(table_name) + c = _pragma_cursor(connection.execute("%stable_info(%s)" % (pragma, qtable))) + found_table = False + columns = [] + while True: + row = c.fetchone() + if row is None: + break + (name, type_, nullable, default, has_default, primary_key) = (row[1], row[2].upper(), not row[3], row[4], row[4] is not None, row[5]) + name = re.sub(r'^\"|\"$', '', name) + if default: + default = re.sub(r"^\'|\'$", '', default) + match = re.match(r'(\w+)(\(.*?\))?', type_) + if match: + coltype = match.group(1) + args = match.group(2) + else: + coltype = "VARCHAR" + args = '' + try: + coltype = self.ischema_names[coltype] + except KeyError: + util.warn("Did not recognize type '%s' of column '%s'" % + (coltype, name)) + coltype = sqltypes.NullType + if args is not None: + args = re.findall(r'(\d+)', args) + coltype = coltype(*[int(a) for a in args]) + + columns.append({ + 'name' : name, + 'type' : coltype, + 'nullable' : nullable, + 'default' : default, + 'primary_key': primary_key + }) + return columns + + @reflection.cache + def get_primary_keys(self, connection, table_name, schema=None, **kw): + cols = self.get_columns(connection, table_name, schema, **kw) + pkeys = [] + for col in cols: + if col['primary_key']: + pkeys.append(col['name']) + return pkeys + + @reflection.cache + def get_foreign_keys(self, connection, table_name, schema=None, **kw): + quote = self.identifier_preparer.quote_identifier + if schema is not None: + pragma = "PRAGMA %s." % quote(schema) + else: + pragma = "PRAGMA " + qtable = quote(table_name) + c = _pragma_cursor(connection.execute("%sforeign_key_list(%s)" % (pragma, qtable))) + fkeys = [] + fks = {} + while True: + row = c.fetchone() + if row is None: + break + (constraint_name, rtbl, lcol, rcol) = (row[0], row[2], row[3], row[4]) + rtbl = re.sub(r'^\"|\"$', '', rtbl) + lcol = re.sub(r'^\"|\"$', '', lcol) + rcol = re.sub(r'^\"|\"$', '', rcol) + try: + fk = fks[constraint_name] + except KeyError: + fk = { + 'name' : constraint_name, + 'constrained_columns' : [], + 'referred_schema' : None, + 'referred_table' : rtbl, + 'referred_columns' : [] + } + fkeys.append(fk) + fks[constraint_name] = fk + + # look up the table based on the given table's engine, not 'self', + # since it could be a ProxyEngine + if lcol not in fk['constrained_columns']: + fk['constrained_columns'].append(lcol) + if rcol not in fk['referred_columns']: + fk['referred_columns'].append(rcol) + return fkeys + + @reflection.cache + def get_indexes(self, connection, table_name, schema=None, **kw): + quote = self.identifier_preparer.quote_identifier + if schema is not None: + pragma = "PRAGMA %s." % quote(schema) + else: + pragma = "PRAGMA " + qtable = quote(table_name) + c = _pragma_cursor(connection.execute("%sindex_list(%s)" % (pragma, qtable))) + indexes = [] + while True: + row = c.fetchone() + if row is None: + break + indexes.append(dict(name=row[1], column_names=[], unique=row[2])) + # loop thru unique indexes to get the column names. 
+ for idx in indexes: + c = connection.execute("%sindex_info(%s)" % (pragma, quote(idx['name']))) + cols = idx['column_names'] + while True: + row = c.fetchone() + if row is None: + break + cols.append(row[2]) + return indexes + + +def _pragma_cursor(cursor): + """work around SQLite issue whereby cursor.description is blank when PRAGMA returns no rows.""" + + if cursor.closed: + cursor._fetchone_impl = lambda: None + return cursor diff --git a/lib/sqlalchemy/dialects/sqlite/pysqlite.py b/lib/sqlalchemy/dialects/sqlite/pysqlite.py new file mode 100644 index 000000000..a1873f33a --- /dev/null +++ b/lib/sqlalchemy/dialects/sqlite/pysqlite.py @@ -0,0 +1,174 @@ +"""Support for the SQLite database via pysqlite. + +Note that pysqlite is the same driver as the ``sqlite3`` +module included with the Python distribution. + +Driver +------ + +When using Python 2.5 and above, the built in ``sqlite3`` driver is +already installed and no additional installation is needed. Otherwise, +the ``pysqlite2`` driver needs to be present. This is the same driver as +``sqlite3``, just with a different name. + +The ``pysqlite2`` driver will be loaded first, and if not found, ``sqlite3`` +is loaded. This allows an explicitly installed pysqlite driver to take +precedence over the built in one. As with all dialects, a specific +DBAPI module may be provided to :func:`~sqlalchemy.create_engine()` to control +this explicitly:: + + from sqlite3 import dbapi2 as sqlite + e = create_engine('sqlite+pysqlite:///file.db', module=sqlite) + +Full documentation on pysqlite is available at: +`<http://www.initd.org/pub/software/pysqlite/doc/usage-guide.html>`_ + +Connect Strings +--------------- + +The file specification for the SQLite database is taken as the "database" portion of +the URL. Note that the format of a url is:: + + driver://user:pass@host/database + +This means that the actual filename to be used starts with the characters to the +**right** of the third slash. So connecting to a relative filepath looks like:: + + # relative path + e = create_engine('sqlite:///path/to/database.db') + +An absolute path, which is denoted by starting with a slash, means you need **four** +slashes:: + + # absolute path + e = create_engine('sqlite:////path/to/database.db') + +To use a Windows path, regular drive specifications and backslashes can be used. +Double backslashes are probably needed:: + + # absolute path on Windows + e = create_engine('sqlite:///C:\\\\path\\\\to\\\\database.db') + +The sqlite ``:memory:`` identifier is the default if no filepath is present. Specify +``sqlite://`` and nothing else:: + + # in-memory database + e = create_engine('sqlite://') + +Threading Behavior +------------------ + +Pysqlite connections do not support being moved between threads, unless +the ``check_same_thread`` Pysqlite flag is set to ``False``. In addition, +when using an in-memory SQLite database, the full database exists only within +the scope of a single connection. It is reported that an in-memory +database does not support being shared between threads regardless of the +``check_same_thread`` flag - which means that a multithreaded +application **cannot** share data from a ``:memory:`` database across threads +unless access to the connection is limited to a single worker thread which communicates +through a queueing mechanism to concurrent threads. 
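+ +If connections must be moved between threads anyway, pysqlite's same-thread check can be relaxed by passing the flag through ``connect_args`` (a sketch only, with a placeholder file path; whether this is actually safe depends entirely on how the connection is used):: + + e = create_engine('sqlite:////tmp/app.db', connect_args={'check_same_thread': False}) +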
+ +To provide a default which accommodates SQLite's default threading capabilities +somewhat reasonably, the SQLite dialect will specify that the :class:`~sqlalchemy.pool.SingletonThreadPool` +be used by default. This pool maintains a single SQLite connection per thread +that is held open up to a count of five concurrent threads. When more than five threads +are used, a cleanup mechanism will dispose of excess unused connections. + +Two optional pool implementations that may be appropriate for particular SQLite usage scenarios: + + * the :class:`sqlalchemy.pool.StaticPool` might be appropriate for a multithreaded + application using an in-memory database, assuming the threading issues inherent in + pysqlite are somehow accommodated. This pool holds persistently onto a single connection + which is never closed, and is returned for all requests. + + * the :class:`sqlalchemy.pool.NullPool` might be appropriate for an application that + makes use of a file-based sqlite database. This pool disables any actual "pooling" + behavior, and simply opens and closes real connections corresponding to the :func:`connect()` + and :func:`close()` methods. SQLite can "connect" to a particular file with very high + efficiency, so this option may actually perform better without the extra overhead + of :class:`SingletonThreadPool`. NullPool will of course render a ``:memory:`` connection + useless since the database would be lost as soon as the connection is "returned" to the pool. + +Unicode +------- + +In contrast to SQLAlchemy's active handling of date and time types for pysqlite, pysqlite's +default behavior regarding Unicode is that all strings are returned as Python unicode objects +in all cases. So even if the :class:`~sqlalchemy.types.Unicode` type is +*not* used, you will still always receive unicode data back from a result set. It is +**strongly** recommended that you do use the :class:`~sqlalchemy.types.Unicode` type +to represent strings, since it will raise a warning if a non-unicode Python string is +passed from the user application. Mixing the usage of non-unicode objects with returned unicode objects can +quickly create confusion, particularly when using the ORM as internal data is not +always represented by an actual database result string. + +""" + +from sqlalchemy.dialects.sqlite.base import SQLiteDialect +from sqlalchemy import schema, exc, pool +from sqlalchemy.engine import default +from sqlalchemy import types as sqltypes +from sqlalchemy import util + +class SQLite_pysqlite(SQLiteDialect): + default_paramstyle = 'qmark' + poolclass = pool.SingletonThreadPool + + # Py3K + #description_encoding = None + + driver = 'pysqlite' + + def __init__(self, **kwargs): + SQLiteDialect.__init__(self, **kwargs) + def vers(num): + return tuple([int(x) for x in num.split('.')]) + if self.dbapi is not None: + sqlite_ver = self.dbapi.version_info + if sqlite_ver < (2, 1, '3'): + util.warn( + ("The installed version of pysqlite2 (%s) is outdated " + "and will cause errors in some cases. Version 2.1.3 " + "or greater is recommended.") % + '.'.join([str(subver) for subver in sqlite_ver])) + if self.dbapi.sqlite_version_info < (3, 3, 8): + self.supports_default_values = False + self.supports_cast = (self.dbapi is None or vers(self.dbapi.sqlite_version) >= vers("3.2.3")) + + @classmethod + def dbapi(cls): + try: + from pysqlite2 import dbapi2 as sqlite + except ImportError, e: + try: + from sqlite3 import dbapi2 as sqlite #try the 2.5+ stdlib name.
+ except ImportError: + raise e + return sqlite + + def _get_server_version_info(self, connection): + return self.dbapi.sqlite_version_info + + def create_connect_args(self, url): + if url.username or url.password or url.host or url.port: + raise exc.ArgumentError( + "Invalid SQLite URL: %s\n" + "Valid SQLite URL forms are:\n" + " sqlite:///:memory: (or, sqlite://)\n" + " sqlite:///relative/path/to/file.db\n" + " sqlite:////absolute/path/to/file.db" % (url,)) + filename = url.database or ':memory:' + + opts = url.query.copy() + util.coerce_kw_type(opts, 'timeout', float) + util.coerce_kw_type(opts, 'isolation_level', str) + util.coerce_kw_type(opts, 'detect_types', int) + util.coerce_kw_type(opts, 'check_same_thread', bool) + util.coerce_kw_type(opts, 'cached_statements', int) + + return ([filename], opts) + + def is_disconnect(self, e): + return isinstance(e, self.dbapi.ProgrammingError) and "Cannot operate on a closed database." in str(e) + +dialect = SQLite_pysqlite diff --git a/lib/sqlalchemy/dialects/sybase/__init__.py b/lib/sqlalchemy/dialects/sybase/__init__.py new file mode 100644 index 000000000..f8baf339e --- /dev/null +++ b/lib/sqlalchemy/dialects/sybase/__init__.py @@ -0,0 +1,4 @@ +from sqlalchemy.dialects.sybase import base, pyodbc + +# default dialect +base.dialect = pyodbc.dialect
\ No newline at end of file diff --git a/lib/sqlalchemy/dialects/sybase/base.py b/lib/sqlalchemy/dialects/sybase/base.py new file mode 100644 index 000000000..6f8c648e4 --- /dev/null +++ b/lib/sqlalchemy/dialects/sybase/base.py @@ -0,0 +1,458 @@ +# sybase.py +# Copyright (C) 2007 Fisch Asset Management AG http://www.fam.ch +# Coding: Alexander Houben alexander.houben@thor-solutions.ch +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Support for the Sybase iAnywhere database. + +This is not a full backend for Sybase ASE. + +This dialect is *not* tested on SQLAlchemy 0.6. + + +Known issues / TODO: + + * Uses the mx.ODBC driver from egenix (version 2.1.0) + * The current version of sqlalchemy.databases.sybase only supports + mx.ODBC.Windows (other platforms such as mx.ODBC.unixODBC still need + some development) + * Support for pyodbc has been built in but is not yet complete (needs + further development) + * Results of running tests/alltests.py: + Ran 934 tests in 287.032s + FAILED (failures=3, errors=1) + * Tested on 'Adaptive Server Anywhere 9' (version 9.0.1.1751) +""" + +import datetime, operator + +from sqlalchemy import util, sql, schema, exc +from sqlalchemy.sql import compiler, expression +from sqlalchemy.engine import default, base +from sqlalchemy import types as sqltypes +from sqlalchemy.sql import operators as sql_operators +from sqlalchemy import MetaData, Table, Column +from sqlalchemy import String, Integer, SMALLINT, CHAR, ForeignKey +from sqlalchemy.dialects.sybase.schema import * + +RESERVED_WORDS = set([ + "add", "all", "alter", "and", + "any", "as", "asc", "backup", + "begin", "between", "bigint", "binary", + "bit", "bottom", "break", "by", + "call", "capability", "cascade", "case", + "cast", "char", "char_convert", "character", + "check", "checkpoint", "close", "comment", + "commit", "connect", "constraint", "contains", + "continue", "convert", "create", "cross", + "cube", "current", "current_timestamp", "current_user", + "cursor", "date", "dbspace", "deallocate", + "dec", "decimal", "declare", "default", + "delete", "deleting", "desc", "distinct", + "do", "double", "drop", "dynamic", + "else", "elseif", "encrypted", "end", + "endif", "escape", "except", "exception", + "exec", "execute", "existing", "exists", + "externlogin", "fetch", "first", "float", + "for", "force", "foreign", "forward", + "from", "full", "goto", "grant", + "group", "having", "holdlock", "identified", + "if", "in", "index", "index_lparen", + "inner", "inout", "insensitive", "insert", + "inserting", "install", "instead", "int", + "integer", "integrated", "intersect", "into", + "iq", "is", "isolation", "join", + "key", "lateral", "left", "like", + "lock", "login", "long", "match", + "membership", "message", "mode", "modify", + "natural", "new", "no", "noholdlock", + "not", "notify", "null", "numeric", + "of", "off", "on", "open", + "option", "options", "or", "order", + "others", "out", "outer", "over", + "passthrough", "precision", "prepare", "primary", + "print", "privileges", "proc", "procedure", + "publication", "raiserror", "readtext", "real", + "reference", "references", "release", "remote", + "remove", "rename", "reorganize", "resource", + "restore", "restrict", "return", "revoke", + "right", "rollback", "rollup", "save", + "savepoint", "scroll", "select", "sensitive", + "session", "set", "setuser", "share", + "smallint", "some", "sqlcode", "sqlstate", + "start", "stop", "subtrans", "subtransaction", + 
"synchronize", "syntax_error", "table", "temporary", + "then", "time", "timestamp", "tinyint", + "to", "top", "tran", "trigger", + "truncate", "tsequal", "unbounded", "union", + "unique", "unknown", "unsigned", "update", + "updating", "user", "using", "validate", + "values", "varbinary", "varchar", "variable", + "varying", "view", "wait", "waitfor", + "when", "where", "while", "window", + "with", "with_cube", "with_lparen", "with_rollup", + "within", "work", "writetext", + ]) + + +class SybaseImage(sqltypes.Binary): + __visit_name__ = 'IMAGE' + +class SybaseBit(sqltypes.TypeEngine): + __visit_name__ = 'BIT' + +class SybaseMoney(sqltypes.TypeEngine): + __visit_name__ = "MONEY" + +class SybaseSmallMoney(SybaseMoney): + __visit_name__ = "SMALLMONEY" + +class SybaseUniqueIdentifier(sqltypes.TypeEngine): + __visit_name__ = "UNIQUEIDENTIFIER" + +class SybaseBoolean(sqltypes.Boolean): + def result_processor(self, dialect): + def process(value): + if value is None: + return None + return value and True or False + return process + + def bind_processor(self, dialect): + def process(value): + if value is True: + return 1 + elif value is False: + return 0 + elif value is None: + return None + else: + return value and True or False + return process + +class SybaseTypeCompiler(compiler.GenericTypeCompiler): + def visit_binary(self, type_): + return self.visit_IMAGE(type_) + + def visit_boolean(self, type_): + return self.visit_BIT(type_) + + def visit_IMAGE(self, type_): + return "IMAGE" + + def visit_BIT(self, type_): + return "BIT" + + def visit_MONEY(self, type_): + return "MONEY" + + def visit_SMALLMONEY(self, type_): + return "SMALLMONEY" + + def visit_UNIQUEIDENTIFIER(self, type_): + return "UNIQUEIDENTIFIER" + +colspecs = { + sqltypes.Binary : SybaseImage, + sqltypes.Boolean : SybaseBoolean, +} + +ischema_names = { + 'integer' : sqltypes.INTEGER, + 'unsigned int' : sqltypes.Integer, + 'unsigned smallint' : sqltypes.SmallInteger, + 'unsigned bigint' : sqltypes.BigInteger, + 'bigint': sqltypes.BIGINT, + 'smallint' : sqltypes.SMALLINT, + 'tinyint' : sqltypes.SmallInteger, + 'varchar' : sqltypes.VARCHAR, + 'long varchar' : sqltypes.Text, + 'char' : sqltypes.CHAR, + 'decimal' : sqltypes.DECIMAL, + 'numeric' : sqltypes.NUMERIC, + 'float' : sqltypes.FLOAT, + 'double' : sqltypes.Numeric, + 'binary' : sqltypes.Binary, + 'long binary' : sqltypes.Binary, + 'varbinary' : sqltypes.Binary, + 'bit': SybaseBit, + 'image' : SybaseImage, + 'timestamp': sqltypes.TIMESTAMP, + 'money': SybaseMoney, + 'smallmoney': SybaseSmallMoney, + 'uniqueidentifier': SybaseUniqueIdentifier, + +} + + +class SybaseExecutionContext(default.DefaultExecutionContext): + + def post_exec(self): + if self.compiled.isinsert: + table = self.compiled.statement.table + # get the inserted values of the primary key + + # get any sequence IDs first (using @@identity) + self.cursor.execute("SELECT @@identity AS lastrowid") + row = self.cursor.fetchone() + lastrowid = int(row[0]) + if lastrowid > 0: + # an IDENTITY was inserted, fetch it + # FIXME: always insert in front ? This only works if the IDENTITY is the first column, no ?! 
+ if not hasattr(self, '_last_inserted_ids') or self._last_inserted_ids is None: + self._last_inserted_ids = [lastrowid] + else: + self._last_inserted_ids = [lastrowid] + self._last_inserted_ids[1:] + + +class SybaseSQLCompiler(compiler.SQLCompiler): + + extract_map = compiler.SQLCompiler.extract_map.copy() + extract_map.update ({ + 'doy': 'dayofyear', + 'dow': 'weekday', + 'milliseconds': 'millisecond' + }) + + def visit_mod(self, binary, **kw): + return "MOD(%s, %s)" % (self.process(binary.left), self.process(binary.right)) + + def bindparam_string(self, name): + res = super(SybaseSQLCompiler, self).bindparam_string(name) + if name.lower().startswith('literal'): + res = 'STRING(%s)' % res + return res + + def get_select_precolumns(self, select): + s = select._distinct and "DISTINCT " or "" + if select._limit: + #if select._limit == 1: + #s += "FIRST " + #else: + #s += "TOP %s " % (select._limit,) + s += "TOP %s " % (select._limit,) + if select._offset: + if not select._limit: + # FIXME: sybase doesn't allow an offset without a limit + # so use a huge value for TOP here + s += "TOP 1000000 " + s += "START AT %s " % (select._offset+1,) + return s + + def limit_clause(self, select): + # Limit in sybase is after the select keyword + return "" + + def visit_binary(self, binary): + """Move bind parameters to the right-hand side of an operator, where possible.""" + if isinstance(binary.left, expression._BindParamClause) and binary.operator == operator.eq: + return self.process(expression._BinaryExpression(binary.right, binary.left, binary.operator)) + else: + return super(SybaseSQLCompiler, self).visit_binary(binary) + + def label_select_column(self, select, column, asfrom): + if isinstance(column, expression.Function): + return column.label(None) + else: + return super(SybaseSQLCompiler, self).label_select_column(select, column, asfrom) + + function_rewrites = {'current_date': 'getdate', + } + def visit_function(self, func): + func.name = self.function_rewrites.get(func.name, func.name) + res = super(SybaseSQLCompiler, self).visit_function(func) + if func.name.lower() == 'getdate': + # apply CAST operator + # FIXME: what about _pyodbc ? 
+ cast = expression._Cast(func, SybaseDate_mxodbc) + # infinite recursion + # res = self.visit_cast(cast) + res = "CAST(%s AS %s)" % (res, self.process(cast.typeclause)) + return res + + def visit_extract(self, extract): + field = self.extract_map.get(extract.field, extract.field) + return 'DATEPART("%s", %s)' % (field, self.process(extract.expr)) + + def for_update_clause(self, select): + # "FOR UPDATE" is only allowed on "DECLARE CURSOR" which SQLAlchemy doesn't use + return '' + + def order_by_clause(self, select): + order_by = self.process(select._order_by_clause) + + # SybaseSQL only allows ORDER BY in subqueries if there is a LIMIT + if order_by and (not self.is_subquery() or select._limit): + return " ORDER BY " + order_by + else: + return "" + + +class SybaseDDLCompiler(compiler.DDLCompiler): + def get_column_specification(self, column, **kwargs): + + colspec = self.preparer.format_column(column) + + if (not getattr(column.table, 'has_sequence', False)) and column.primary_key and \ + column.autoincrement and isinstance(column.type, sqltypes.Integer): + if column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional): + column.sequence = schema.Sequence(column.name + '_seq') + + if hasattr(column, 'sequence'): + column.table.has_sequence = column + #colspec += " numeric(30,0) IDENTITY" + colspec += " Integer IDENTITY" + else: + colspec += " " + self.dialect.type_compiler.process(column.type) + + if not column.nullable: + colspec += " NOT NULL" + + default = self.get_column_default_string(column) + if default is not None: + colspec += " DEFAULT " + default + + return colspec + + def visit_drop_index(self, drop): + index = drop.element + return "\nDROP INDEX %s.%s" % ( + self.preparer.quote_identifier(index.table.name), + self.preparer.quote(self._validate_identifier(index.name, False), index.quote) + ) + +class SybaseIdentifierPreparer(compiler.IdentifierPreparer): + reserved_words = RESERVED_WORDS + +class SybaseDialect(default.DefaultDialect): + name = 'sybase' + supports_unicode_statements = False + supports_sane_rowcount = False + supports_sane_multi_rowcount = False + colspecs = colspecs + ischema_names = ischema_names + + type_compiler = SybaseTypeCompiler + statement_compiler = SybaseSQLCompiler + ddl_compiler = SybaseDDLCompiler + preparer = SybaseIdentifierPreparer + + schema_name = "dba" + + def __init__(self, **params): + super(SybaseDialect, self).__init__(**params) + self.text_as_varchar = False + + def last_inserted_ids(self): + return self.context.last_inserted_ids + + def get_default_schema_name(self, connection): + return self.schema_name + + def table_names(self, connection, schema): + """Ignore the schema and the charset for now.""" + s = sql.select([tables.c.table_name], + sql.not_(tables.c.table_name.like("SYS%")) and + tables.c.creator >= 100 + ) + rp = connection.execute(s) + return [row[0] for row in rp.fetchall()] + + def has_table(self, connection, tablename, schema=None): + # FIXME: ignore schemas for sybase + s = sql.select([tables.c.table_name], tables.c.table_name == tablename) + + c = connection.execute(s) + row = c.fetchone() + return row is not None + + def reflecttable(self, connection, table, include_columns): + # Get base columns + if table.schema is not None: + current_schema = table.schema + else: + current_schema = self.get_default_schema_name(connection) + + s = sql.select([columns, domains], tables.c.table_name==table.name, from_obj=[columns.join(tables).join(domains)], order_by=[columns.c.column_id]) + + 
+        c = connection.execute(s)
+        found_table = False
+        # make sure we append the columns in the correct order
+        while True:
+            row = c.fetchone()
+            if row is None:
+                break
+            found_table = True
+            (name, type, nullable, charlen, numericprec, numericscale,
+             default, primary_key, max_identity, table_id, column_id) = (
+                row[columns.c.column_name],
+                row[domains.c.domain_name],
+                row[columns.c.nulls] == 'Y',
+                row[columns.c.width],
+                row[domains.c.precision],
+                row[columns.c.scale],
+                row[columns.c.default],
+                row[columns.c.pkey] == 'Y',
+                row[columns.c.max_identity],
+                row[tables.c.table_id],
+                row[columns.c.column_id],
+                )
+            if include_columns and name not in include_columns:
+                continue
+
+            # FIXME: else problems with SybaseBinary(size)
+            if numericscale == 0:
+                numericscale = None
+
+            args = []
+            for a in (charlen, numericprec, numericscale):
+                if a is not None:
+                    args.append(a)
+            coltype = self.ischema_names.get(type, None)
+            if coltype == SybaseString and charlen == -1:
+                coltype = SybaseText()
+            elif coltype is None:
+                util.warn("Did not recognize type '%s' of column '%s'" %
+                          (type, name))
+                # NULLTYPE is an instance, not a class; don't call it below
+                coltype = sqltypes.NULLTYPE
+            else:
+                coltype = coltype(*args)
+            colargs = []
+            if default is not None:
+                colargs.append(schema.DefaultClause(sql.text(default)))
+
+            # any sequences ?
+            col = schema.Column(name, coltype, nullable=nullable,
+                                primary_key=primary_key, *colargs)
+            if int(max_identity) > 0:
+                col.sequence = schema.Sequence(name + '_identity')
+                col.sequence.start = int(max_identity)
+                col.sequence.increment = 1
+
+            # append the column
+            table.append_column(col)
+
+        # any foreign key constraint for this table ?
+        # note: no multi-column foreign keys are considered
+        s = "select st1.table_name, sc1.column_name, st2.table_name, sc2.column_name " \
+            "from systable as st1 " \
+            "join sysfkcol on st1.table_id=sysfkcol.foreign_table_id " \
+            "join sysforeignkey " \
+            "join systable as st2 on sysforeignkey.primary_table_id = st2.table_id " \
+            "join syscolumn as sc1 on sysfkcol.foreign_column_id=sc1.column_id " \
+            "and sc1.table_id=st1.table_id " \
+            "join syscolumn as sc2 on sysfkcol.primary_column_id=sc2.column_id " \
+            "and sc2.table_id=st2.table_id " \
+            "where st1.table_name='%(table_name)s';" % {'table_name': table.name}
+        c = connection.execute(s)
+        foreignKeys = {}
+        while True:
+            row = c.fetchone()
+            if row is None:
+                break
+            (foreign_table, foreign_column, primary_table, primary_column) = (
+                row[0], row[1], row[2], row[3],
+            )
+            if primary_table not in foreignKeys:
+                foreignKeys[primary_table] = [
+                    ['%s' % (foreign_column)],
+                    ['%s.%s' % (primary_table, primary_column)]]
+            else:
+                foreignKeys[primary_table][0].append('%s' % (foreign_column))
+                foreignKeys[primary_table][1].append('%s.%s' % (primary_table, primary_column))
+        for primary_table in foreignKeys.iterkeys():
+            table.append_constraint(
+                schema.ForeignKeyConstraint(foreignKeys[primary_table][0],
+                                            foreignKeys[primary_table][1],
+                                            link_to_name=True))
+
+        if not found_table:
+            raise exc.NoSuchTableError(table.name)
+
diff --git a/lib/sqlalchemy/dialects/sybase/mxodbc.py b/lib/sqlalchemy/dialects/sybase/mxodbc.py
new file mode 100644
index 000000000..86a23d5bc
--- /dev/null
+++ b/lib/sqlalchemy/dialects/sybase/mxodbc.py
@@ -0,0 +1,10 @@
+from sqlalchemy.dialects.sybase.base import SybaseDialect, SybaseExecutionContext
+from sqlalchemy.connectors.mxodbc import MxODBCConnector
+
+class SybaseExecutionContext_mxodbc(SybaseExecutionContext):
+    pass
+
+class Sybase_mxodbc(MxODBCConnector, SybaseDialect):
+    execution_ctx_cls = SybaseExecutionContext_mxodbc
+
+dialect = Sybase_mxodbc
\ No newline at end of file
diff --git a/lib/sqlalchemy/dialects/sybase/pyodbc.py b/lib/sqlalchemy/dialects/sybase/pyodbc.py
new file mode 100644
index 000000000..61c6f3292
--- /dev/null
+++ b/lib/sqlalchemy/dialects/sybase/pyodbc.py
@@ -0,0 +1,11 @@
+from sqlalchemy.dialects.sybase.base import SybaseDialect, SybaseExecutionContext
+from sqlalchemy.connectors.pyodbc import PyODBCConnector
+
+class SybaseExecutionContext_pyodbc(SybaseExecutionContext):
+    pass
+
+
+class Sybase_pyodbc(PyODBCConnector, SybaseDialect):
+    execution_ctx_cls = SybaseExecutionContext_pyodbc
+
+dialect = Sybase_pyodbc
\ No newline at end of file
diff --git a/lib/sqlalchemy/dialects/sybase/schema.py b/lib/sqlalchemy/dialects/sybase/schema.py
new file mode 100644
index 000000000..15ac6b27b
--- /dev/null
+++ b/lib/sqlalchemy/dialects/sybase/schema.py
@@ -0,0 +1,51 @@
+from sqlalchemy import *
+
+ischema = MetaData()
+
+tables = Table("SYSTABLE", ischema,
+    Column("table_id", Integer, primary_key=True),
+    Column("file_id", SMALLINT),
+    Column("table_name", CHAR(128)),
+    Column("table_type", CHAR(10)),
+    Column("creator", Integer),
+    #schema="information_schema"
+    )
+
+domains = Table("SYSDOMAIN", ischema,
+    Column("domain_id", Integer, primary_key=True),
+    Column("domain_name", CHAR(128)),
+    Column("type_id", SMALLINT),
+    Column("precision", SMALLINT, quote=True),
+    #schema="information_schema"
+    )
+
+columns = Table("SYSCOLUMN", ischema,
+    Column("column_id", Integer, primary_key=True),
+    Column("table_id", Integer, ForeignKey(tables.c.table_id)),
+    Column("pkey", CHAR(1)),
+    Column("column_name", CHAR(128)),
+    Column("nulls", CHAR(1)),
+    Column("width", SMALLINT),
+    Column("domain_id", SMALLINT, ForeignKey(domains.c.domain_id)),
+    # FIXME: should be mx.BIGINT
+    Column("max_identity", Integer),
+    # FIXME: should be mx.ODBC.Windows.LONGVARCHAR
+    Column("default", String),
+    Column("scale", Integer),
+    #schema="information_schema"
+    )
+
+foreignkeys = Table("SYSFOREIGNKEY", ischema,
+    Column("foreign_table_id", Integer, ForeignKey(tables.c.table_id), primary_key=True),
+    Column("foreign_key_id", SMALLINT, primary_key=True),
+    Column("primary_table_id", Integer, ForeignKey(tables.c.table_id)),
+    #schema="information_schema"
+    )
+
+fkcols = Table("SYSFKCOL", ischema,
+    Column("foreign_table_id", Integer, ForeignKey(columns.c.table_id), primary_key=True),
+    Column("foreign_key_id", SMALLINT, ForeignKey(foreignkeys.c.foreign_key_id), primary_key=True),
+    Column("foreign_column_id", Integer, ForeignKey(columns.c.column_id), primary_key=True),
+    Column("primary_column_id", Integer),
+    #schema="information_schema"
+    )
+
diff --git a/lib/sqlalchemy/dialects/type_migration_guidelines.txt b/lib/sqlalchemy/dialects/type_migration_guidelines.txt
new file mode 100644
index 000000000..8ed1a1797
--- /dev/null
+++ b/lib/sqlalchemy/dialects/type_migration_guidelines.txt
@@ -0,0 +1,145 @@
+Rules for Migrating TypeEngine classes to 0.6
+---------------------------------------------
+
+1. the TypeEngine classes are used for:
+
+  a. Specifying behavior which needs to occur for bind parameters
+     or result row columns.
+
+  b. Specifying types that are entirely specific to the database
+     in use and have no analogue in the sqlalchemy.types package.
+
+  c. Specifying types where there is an analogue in sqlalchemy.types,
+     but the database in use takes vendor-specific flags for those
+     types.
+
+  d. If a TypeEngine class doesn't provide any of this, it should be
+     *removed* from the dialect.
+
+2. the TypeEngine classes are *no longer* used for generating DDL.  Dialects
+now have a TypeCompiler subclass which uses the same visit_XXX model as
+other compilers.
+
+3. the "ischema_names" and "colspecs" dictionaries are now required members
+on the Dialect class; the sketch below shows how rules 1-3 fit together.
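+(Aside: the following is an illustrative sketch only, written against an
+invented "mydb" dialect; the class names _MyDBNumeric, MyDBTypeCompiler and
+MyDBDialect are made up for this example and are not part of any shipped
+dialect.)
+
+  from sqlalchemy import types as sqltypes
+  from sqlalchemy.sql import compiler
+  from sqlalchemy.engine import default
+
+  # rule 1a: bind behavior only, so a private MixedCase class
+  class _MyDBNumeric(sqltypes.Numeric):
+      def bind_processor(self, dialect):
+          def process(value):
+              if value is not None:
+                  return str(value)
+              return None
+          return process
+
+  # rule 2: DDL is rendered by a TypeCompiler, not by the type itself
+  class MyDBTypeCompiler(compiler.GenericTypeCompiler):
+      def visit_NUMERIC(self, type_):
+          return "NUMERIC(%s, %s)" % (type_.precision, type_.scale)
+
+  # rule 3: both dictionaries are required members on the Dialect
+  class MyDBDialect(default.DefaultDialect):
+      name = 'mydb'
+      type_compiler = MyDBTypeCompiler
+      colspecs = {sqltypes.Numeric: _MyDBNumeric}
+      ischema_names = {'numeric': sqltypes.NUMERIC}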
+4. The names of types within dialects are now important.  If a dialect-specific type
+is a subclass of an existing generic type and is only provided for bind/result behavior,
+the current mixed case naming can remain, i.e. _PGNumeric for Numeric - in this case,
+end users would never need to use _PGNumeric directly.  However, if a dialect-specific
+type is specifying a type *or* arguments that are not present generically, it should
+match the real name of the type on that backend, in uppercase.  E.g. postgresql.INET,
+mysql.ENUM, postgresql.ARRAY.
+
+Or follow this handy decision flow (the original flowchart, restated as an
+outline):
+
+  Is the type meant to provide bind/result behavior to a generic
+  (i.e. MixedCase) type in types.py ?
+
+    yes -> name the type using _MixedCase, i.e. _OracleBoolean.  It stays
+           private to the dialect and is invoked *only* via the colspecs
+           dict.  Subclass the closest MixedCase type in types.py, i.e.
+           class _DateTime(types.DateTime).
+
+    no  -> Is the type the same name as an UPPERCASE type in types.py ?
+
+             yes -> Does your type need special behavior or arguments ?
+
+                      no  -> don't make a type; make sure the dialect's
+                             base.py imports the types.py UPPERCASE name
+                             into its namespace.
+
+                      yes -> continue with the "no" branch below.
+
+             no  -> name the type identically as within the DB, using
+                    UPPERCASE (i.e. BIT, NCHAR, INTERVAL).  Users can
+                    import it.  Then: is the name of this type identical
+                    to an UPPERCASE name in types.py ?
+
+                      no  -> subclass the closest MixedCase type in
+                             types.py, i.e. class DATETIME2(types.DateTime),
+                             class BIT(types.TypeEngine).
+
+                      yes -> the type should subclass the UPPERCASE type in
+                             types.py (i.e. class BLOB(types.BLOB)).
+
+Example 1. pysqlite needs bind/result processing for the DateTime type in types.py,
+which applies to all DateTimes and subclasses.  It's named _SLDateTime and
+subclasses types.DateTime.
+
+Example 2. MS-SQL has a TIME type which takes a non-standard "precision" argument
+that is rendered within DDL.  So it's named TIME in the MS-SQL dialect's base.py,
+and subclasses types.TIME.  Users can then say mssql.TIME(precision=10).
+
+Example 3. MS-SQL dialects also need special bind/result processing for date values.
+But its DATE type doesn't render DDL differently than that of a plain
+DATE, i.e. it takes no special arguments.  Therefore we are just adding behavior
+to types.Date, so it's named _MSDate in the MS-SQL dialect's base.py, and subclasses
+types.Date.
+
+Example 4. MySQL has a SET type; there's no analogue for it in types.py.  So
+MySQL names it SET in the dialect's base.py, and it subclasses types.String, since
+it ultimately deals with strings.
+
+Example 5. Postgresql has a DATETIME type.  The DBAPIs handle dates correctly,
+and no special arguments are used in PG's DDL beyond what types.py provides.
+Postgresql dialect therefore imports types.DATETIME into its base.py.
+
+Ideally one should be able to specify a schema using names imported completely from a
+dialect, all matching the real name on that backend:
+
+  from sqlalchemy.dialects.postgresql import base as pg
+
+  t = Table('mytable', metadata,
+      Column('id', pg.INTEGER, primary_key=True),
+      Column('name', pg.VARCHAR(300)),
+      Column('inetaddr', pg.INET)
+  )
+
+where above, the INTEGER and VARCHAR types are ultimately from sqlalchemy.types,
+but the PG dialect makes them available in its own namespace.
+
+5. "colspecs" now is a dictionary of generic or uppercased types from sqlalchemy.types
+linked to types specified in the dialect.  Again, if a type in the dialect does not
+specify any special behavior for bind_processor() or result_processor() and does not
+indicate a special type only available in this database, it must be *removed* from the
+module and from this dictionary.
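+(Aside: an invented illustration of rule 5 -- _MyDBTime earns its colspecs
+entry only because it overrides result_processor(); a do-nothing subclass
+would have to be deleted along with its entry.  The name and the DBAPI
+behavior described in the comment are hypothetical.)
+
+  from sqlalchemy import types as sqltypes
+
+  class _MyDBTime(sqltypes.Time):
+      def result_processor(self, dialect):
+          def process(value):
+              # hypothetical: strip a date component the DBAPI tacks on
+              if value is not None:
+                  return value.time()
+              return None
+          return process
+
+  colspecs = {
+      sqltypes.Time: _MyDBTime,
+  }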
+6. "ischema_names" indicates string descriptions of types as returned from the database
+linked to TypeEngine classes.
+
+  a. The string name should be matched to the most specific type possible within
+     sqlalchemy.types, unless there is no matching type within sqlalchemy.types, in which
+     case it points to a dialect type.  *It doesn't matter* if the dialect has its
+     own subclass of that type with special bind/result behavior - reflect to the types.py
+     UPPERCASE type as much as possible.  With very few exceptions, all types
+     should reflect to an UPPERCASE type.
+
+  b. If the dialect contains a matching dialect-specific type that takes extra arguments
+     which the generic one does not, then point to the dialect-specific type.  E.g.
+     mssql.VARCHAR takes a "collation" parameter which should be preserved.
+
+7. DDL, or what was formerly issued by "get_col_spec()", is now handled exclusively by
+a subclass of compiler.GenericTypeCompiler.
+
+  a. your TypeCompiler class will receive generic and uppercase types from
+     sqlalchemy.types.  Do not assume the presence of dialect-specific attributes on
+     these types.
+
+  b. the visit_UPPERCASE methods on GenericTypeCompiler should *not* be overridden with
+     methods that produce a different DDL name.  Uppercase types don't do any kind of
+     "guessing" - if visit_TIMESTAMP is called, the DDL should render as TIMESTAMP in
+     all cases, regardless of whether or not that type is legal on the backend database.
+
+  c. the visit_UPPERCASE methods *should* be overridden with methods that add additional
+     arguments and flags to those types.
+
+  d. the visit_lowercase methods are overridden to provide an interpretation of a generic
+     type.  E.g. visit_binary() might be overridden to say "return self.visit_BIT(type_)".
+
+  e. visit_lowercase methods should *never* render strings directly - it should always
+     be via calling a visit_UPPERCASE() method.
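+(Aside: an invented sketch of points b-e above, again using the made-up
+MyDBTypeCompiler name -- visit_TIME adds a vendor "precision" argument
+without renaming the type, and visit_binary interprets the generic type by
+delegating to a visit_UPPERCASE method rather than rendering a string.)
+
+  from sqlalchemy.sql import compiler
+
+  class MyDBTypeCompiler(compiler.GenericTypeCompiler):
+      def visit_TIME(self, type_):
+          # (c) add arguments/flags, but the name stays TIME per (b)
+          if getattr(type_, 'precision', None) is not None:
+              return "TIME(%d)" % type_.precision
+          return "TIME"
+
+      def visit_BIT(self, type_):
+          return "BIT"
+
+      def visit_binary(self, type_):
+          # (d)/(e) interpret the generic type via visit_UPPERCASE
+          return self.visit_BIT(type_)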
