diff options
Diffstat (limited to 'lib/sqlalchemy')
51 files changed, 2912 insertions, 1244 deletions
diff --git a/lib/sqlalchemy/__init__.py b/lib/sqlalchemy/__init__.py index 853566172..d184e1fbf 100644 --- a/lib/sqlalchemy/__init__.py +++ b/lib/sqlalchemy/__init__.py @@ -25,6 +25,7 @@ from .sql import ( extract, false, func, + funcfilter, insert, intersect, intersect_all, diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index ba3050ae5..dad02ee0f 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -846,7 +846,7 @@ class MSExecutionContext(default.DefaultExecutionContext): "SET IDENTITY_INSERT %s OFF" % self.dialect.identifier_preparer. format_table( self.compiled.statement.table))) - except: + except Exception: pass def get_result_proxy(self): diff --git a/lib/sqlalchemy/dialects/mssql/pymssql.py b/lib/sqlalchemy/dialects/mssql/pymssql.py index 8f76336ae..b5a1bc566 100644 --- a/lib/sqlalchemy/dialects/mssql/pymssql.py +++ b/lib/sqlalchemy/dialects/mssql/pymssql.py @@ -63,7 +63,7 @@ class MSDialect_pymssql(MSDialect): def _get_server_version_info(self, connection): vers = connection.scalar("select @@version") m = re.match( - r"Microsoft SQL Server.*? - (\d+).(\d+).(\d+).(\d+)", vers) + r"Microsoft .*? - (\d+).(\d+).(\d+).(\d+)", vers) if m: return tuple(int(x) for x in m.group(1, 2, 3, 4)) else: diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index 7ccd59abb..2fb054d0c 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -341,6 +341,29 @@ reflection will not include foreign keys. For these tables, you may supply a :ref:`mysql_storage_engines` +.. _mysql_unique_constraints: + +MySQL Unique Constraints and Reflection +--------------------------------------- + +SQLAlchemy supports both the :class:`.Index` construct with the +flag ``unique=True``, indicating a UNIQUE index, as well as the +:class:`.UniqueConstraint` construct, representing a UNIQUE constraint. +Both objects/syntaxes are supported by MySQL when emitting DDL to create +these constraints. However, MySQL does not have a unique constraint +construct that is separate from a unique index; that is, the "UNIQUE" +constraint on MySQL is equivalent to creating a "UNIQUE INDEX". + +When reflecting these constructs, the :meth:`.Inspector.get_indexes` +and the :meth:`.Inspector.get_unique_constraints` methods will **both** +return an entry for a UNIQUE index in MySQL. However, when performing +full table reflection using ``Table(..., autoload=True)``, +the :class:`.UniqueConstraint` construct is +**not** part of the fully reflected :class:`.Table` construct under any +circumstances; this construct is always represented by a :class:`.Index` +with the ``unique=True`` setting present in the :attr:`.Table.indexes` +collection. + .. _mysql_timestamp_null: @@ -2317,7 +2340,7 @@ class MySQLDialect(default.DefaultDialect): # basic operations via autocommit fail. try: dbapi_connection.commit() - except: + except Exception: if self.server_version_info < (3, 23, 15): args = sys.exc_info()[1].args if args and args[0] == 1064: @@ -2329,7 +2352,7 @@ class MySQLDialect(default.DefaultDialect): try: dbapi_connection.rollback() - except: + except Exception: if self.server_version_info < (3, 23, 15): args = sys.exc_info()[1].args if args and args[0] == 1064: @@ -2590,7 +2613,8 @@ class MySQLDialect(default.DefaultDialect): return [ { 'name': key['name'], - 'column_names': [col[0] for col in key['columns']] + 'column_names': [col[0] for col in key['columns']], + 'duplicates_index': key['name'], } for key in parsed_state.keys if key['type'] == 'UNIQUE' diff --git a/lib/sqlalchemy/dialects/mysql/mysqlconnector.py b/lib/sqlalchemy/dialects/mysql/mysqlconnector.py index e51e80005..417e1ad6f 100644 --- a/lib/sqlalchemy/dialects/mysql/mysqlconnector.py +++ b/lib/sqlalchemy/dialects/mysql/mysqlconnector.py @@ -21,6 +21,7 @@ from .base import (MySQLDialect, MySQLExecutionContext, BIT) from ... import util +import re class MySQLExecutionContext_mysqlconnector(MySQLExecutionContext): @@ -31,18 +32,34 @@ class MySQLExecutionContext_mysqlconnector(MySQLExecutionContext): class MySQLCompiler_mysqlconnector(MySQLCompiler): def visit_mod_binary(self, binary, operator, **kw): - return self.process(binary.left, **kw) + " %% " + \ - self.process(binary.right, **kw) + if self.dialect._mysqlconnector_double_percents: + return self.process(binary.left, **kw) + " %% " + \ + self.process(binary.right, **kw) + else: + return self.process(binary.left, **kw) + " % " + \ + self.process(binary.right, **kw) def post_process_text(self, text): - return text.replace('%', '%%') + if self.dialect._mysqlconnector_double_percents: + return text.replace('%', '%%') + else: + return text + + def escape_literal_column(self, text): + if self.dialect._mysqlconnector_double_percents: + return text.replace('%', '%%') + else: + return text class MySQLIdentifierPreparer_mysqlconnector(MySQLIdentifierPreparer): def _escape_identifier(self, value): value = value.replace(self.escape_quote, self.escape_to_quote) - return value.replace("%", "%%") + if self.dialect._mysqlconnector_double_percents: + return value.replace("%", "%%") + else: + return value class _myconnpyBIT(BIT): @@ -55,8 +72,6 @@ class _myconnpyBIT(BIT): class MySQLDialect_mysqlconnector(MySQLDialect): driver = 'mysqlconnector' - if util.py2k: - supports_unicode_statements = False supports_unicode_binds = True supports_sane_rowcount = True @@ -77,6 +92,10 @@ class MySQLDialect_mysqlconnector(MySQLDialect): } ) + @util.memoized_property + def supports_unicode_statements(self): + return util.py3k or self._mysqlconnector_version_info > (2, 0) + @classmethod def dbapi(cls): from mysql import connector @@ -103,10 +122,25 @@ class MySQLDialect_mysqlconnector(MySQLDialect): 'client_flags', ClientFlag.get_default()) client_flags |= ClientFlag.FOUND_ROWS opts['client_flags'] = client_flags - except: + except Exception: pass return [[], opts] + @util.memoized_property + def _mysqlconnector_version_info(self): + if self.dbapi and hasattr(self.dbapi, '__version__'): + m = re.match(r'(\d+)\.(\d+)(?:\.(\d+))?', + self.dbapi.__version__) + if m: + return tuple( + int(x) + for x in m.group(1, 2, 3) + if x is not None) + + @util.memoized_property + def _mysqlconnector_double_percents(self): + return not util.py3k and self._mysqlconnector_version_info < (2, 0) + def _get_server_version_info(self, connection): dbapi_con = connection.connection version = dbapi_con.get_server_version() diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index 81a9f1a95..6df38e57e 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -213,6 +213,21 @@ is reflected and the type is reported as ``DATE``, the time-supporting examining the type of column for use in special Python translations or for migrating schemas to other database backends. +Oracle Table Options +------------------------- + +The CREATE TABLE phrase supports the following options with Oracle +in conjunction with the :class:`.Table` construct: + + +* ``ON COMMIT``:: + + Table( + "some_table", metadata, ..., + prefixes=['GLOBAL TEMPORARY'], oracle_on_commit='PRESERVE ROWS') + +.. versionadded:: 1.0.0 + """ import re @@ -784,11 +799,22 @@ class OracleDDLCompiler(compiler.DDLCompiler): return super(OracleDDLCompiler, self).\ visit_create_index(create, include_schema=True) + def post_create_table(self, table): + table_opts = [] + opts = table.dialect_options['oracle'] + + if opts['on_commit']: + on_commit_options = opts['on_commit'].replace("_", " ").upper() + table_opts.append('\n ON COMMIT %s' % on_commit_options) + + return ''.join(table_opts) + class OracleIdentifierPreparer(compiler.IdentifierPreparer): reserved_words = set([x.lower() for x in RESERVED_WORDS]) - illegal_initial_characters = set(range(0, 10)).union(["_", "$"]) + illegal_initial_characters = set( + (str(dig) for dig in range(0, 10))).union(["_", "$"]) def _bindparam_requires_quotes(self, value): """Return True if the given identifier requires quoting.""" @@ -842,7 +868,10 @@ class OracleDialect(default.DefaultDialect): reflection_options = ('oracle_resolve_synonyms', ) construct_arguments = [ - (sa_schema.Table, {"resolve_synonyms": False}) + (sa_schema.Table, { + "resolve_synonyms": False, + "on_commit": None + }) ] def __init__(self, @@ -1029,7 +1058,21 @@ class OracleDialect(default.DefaultDialect): "WHERE nvl(tablespace_name, 'no tablespace') NOT IN " "('SYSTEM', 'SYSAUX') " "AND OWNER = :owner " - "AND IOT_NAME IS NULL") + "AND IOT_NAME IS NULL " + "AND DURATION IS NULL") + cursor = connection.execute(s, owner=schema) + return [self.normalize_name(row[0]) for row in cursor] + + @reflection.cache + def get_temp_table_names(self, connection, **kw): + schema = self.denormalize_name(self.default_schema_name) + s = sql.text( + "SELECT table_name FROM all_tables " + "WHERE nvl(tablespace_name, 'no tablespace') NOT IN " + "('SYSTEM', 'SYSAUX') " + "AND OWNER = :owner " + "AND IOT_NAME IS NULL " + "AND DURATION IS NOT NULL") cursor = connection.execute(s, owner=schema) return [self.normalize_name(row[0]) for row in cursor] diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 575d2a6dd..baa640eaa 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -401,6 +401,29 @@ The value passed to the keyword argument will be simply passed through to the underlying CREATE INDEX command, so it *must* be a valid index type for your version of PostgreSQL. + +.. _postgresql_index_reflection: + +Postgresql Index Reflection +--------------------------- + +The Postgresql database creates a UNIQUE INDEX implicitly whenever the +UNIQUE CONSTRAINT construct is used. When inspecting a table using +:class:`.Inspector`, the :meth:`.Inspector.get_indexes` +and the :meth:`.Inspector.get_unique_constraints` will report on these +two constructs distinctly; in the case of the index, the key +``duplicates_constraint`` will be present in the index entry if it is +detected as mirroring a constraint. When performing reflection using +``Table(..., autoload=True)``, the UNIQUE INDEX is **not** returned +in :attr:`.Table.indexes` when it is detected as mirroring a +:class:`.UniqueConstraint` in the :attr:`.Table.constraints` collection. + +.. versionchanged:: 1.0.0 - :class:`.Table` reflection now includes + :class:`.UniqueConstraint` objects present in the :attr:`.Table.constraints` + collection; the Postgresql backend will no longer include a "mirrored" + :class:`.Index` construct in :attr:`.Table.indexes` if it is detected + as corresponding to a unique constraint. + Special Reflection Options -------------------------- @@ -1679,6 +1702,19 @@ class PGInspector(reflection.Inspector): schema = schema or self.default_schema_name return self.dialect._load_enums(self.bind, schema) + def get_foreign_table_names(self, schema=None): + """Return a list of FOREIGN TABLE names. + + Behavior is similar to that of :meth:`.Inspector.get_table_names`, + except that the list is limited to those tables tha report a + ``relkind`` value of ``f``. + + .. versionadded:: 1.0.0 + + """ + schema = schema or self.default_schema_name + return self.dialect._get_foreign_table_names(self.bind, schema) + class CreateEnumType(schema._CreateDropBase): __visit_name__ = "create_enum_type" @@ -2024,7 +2060,7 @@ class PGDialect(default.DefaultDialect): FROM pg_catalog.pg_class c LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace WHERE (%s) - AND c.relname = :table_name AND c.relkind in ('r','v') + AND c.relname = :table_name AND c.relkind in ('r', 'v', 'm', 'f') """ % schema_where_clause # Since we're binding to unicode, table_name and schema_name must be # unicode. @@ -2078,6 +2114,24 @@ class PGDialect(default.DefaultDialect): return [row[0] for row in result] @reflection.cache + def _get_foreign_table_names(self, connection, schema=None, **kw): + if schema is not None: + current_schema = schema + else: + current_schema = self.default_schema_name + + result = connection.execute( + sql.text("SELECT relname FROM pg_class c " + "WHERE relkind = 'f' " + "AND '%s' = (select nspname from pg_namespace n " + "where n.oid = c.relnamespace) " % + current_schema, + typemap={'relname': sqltypes.Unicode} + ) + ) + return [row[0] for row in result] + + @reflection.cache def get_view_names(self, connection, schema=None, **kw): if schema is not None: current_schema = schema @@ -2086,7 +2140,7 @@ class PGDialect(default.DefaultDialect): s = """ SELECT relname FROM pg_class c - WHERE relkind = 'v' + WHERE relkind IN ('m', 'v') AND '%(schema)s' = (select nspname from pg_namespace n where n.oid = c.relnamespace) """ % dict(schema=current_schema) @@ -2439,16 +2493,21 @@ class PGDialect(default.DefaultDialect): SELECT i.relname as relname, ix.indisunique, ix.indexprs, ix.indpred, - a.attname, a.attnum, ix.indkey%s + a.attname, a.attnum, c.conrelid, ix.indkey%s FROM pg_class t join pg_index ix on t.oid = ix.indrelid - join pg_class i on i.oid=ix.indexrelid + join pg_class i on i.oid = ix.indexrelid left outer join pg_attribute a - on t.oid=a.attrelid and %s + on t.oid = a.attrelid and %s + left outer join + pg_constraint c + on (ix.indrelid = c.conrelid and + ix.indexrelid = c.conindid and + c.contype in ('p', 'u', 'x')) WHERE - t.relkind = 'r' + t.relkind IN ('r', 'v', 'f', 'm') and t.oid = :table_oid and ix.indisprimary = 'f' ORDER BY @@ -2469,7 +2528,7 @@ class PGDialect(default.DefaultDialect): sv_idx_name = None for row in c.fetchall(): - idx_name, unique, expr, prd, col, col_num, idx_key = row + idx_name, unique, expr, prd, col, col_num, conrelid, idx_key = row if expr: if idx_name != sv_idx_name: @@ -2486,18 +2545,27 @@ class PGDialect(default.DefaultDialect): % idx_name) sv_idx_name = idx_name + has_idx = idx_name in indexes index = indexes[idx_name] if col is not None: index['cols'][col_num] = col - index['key'] = [int(k.strip()) for k in idx_key.split()] - index['unique'] = unique - - return [ - {'name': name, - 'unique': idx['unique'], - 'column_names': [idx['cols'][i] for i in idx['key']]} - for name, idx in indexes.items() - ] + if not has_idx: + index['key'] = [int(k.strip()) for k in idx_key.split()] + index['unique'] = unique + if conrelid is not None: + index['duplicates_constraint'] = idx_name + + result = [] + for name, idx in indexes.items(): + entry = { + 'name': name, + 'unique': idx['unique'], + 'column_names': [idx['cols'][i] for i in idx['key']] + } + if 'duplicates_constraint' in idx: + entry['duplicates_constraint'] = idx['duplicates_constraint'] + result.append(entry) + return result @reflection.cache def get_unique_constraints(self, connection, table_name, diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2.py b/lib/sqlalchemy/dialects/postgresql/psycopg2.py index e6450c97f..1a2a1ffe4 100644 --- a/lib/sqlalchemy/dialects/postgresql/psycopg2.py +++ b/lib/sqlalchemy/dialects/postgresql/psycopg2.py @@ -32,10 +32,25 @@ psycopg2-specific keyword arguments which are accepted by way of enabling this mode on a per-execution basis. * ``use_native_unicode``: Enable the usage of Psycopg2 "native unicode" mode per connection. True by default. + + .. seealso:: + + :ref:`psycopg2_disable_native_unicode` + * ``isolation_level``: This option, available for all PostgreSQL dialects, includes the ``AUTOCOMMIT`` isolation level when using the psycopg2 - dialect. See :ref:`psycopg2_isolation_level`. + dialect. + + .. seealso:: + + :ref:`psycopg2_isolation_level` + +* ``client_encoding``: sets the client encoding in a libpq-agnostic way, + using psycopg2's ``set_client_encoding()`` method. + + .. seealso:: + :ref:`psycopg2_unicode` Unix Domain Connections ------------------------ @@ -75,8 +90,10 @@ The following DBAPI-specific options are respected when used with If ``None`` or not set, the ``server_side_cursors`` option of the :class:`.Engine` is used. -Unicode -------- +.. _psycopg2_unicode: + +Unicode with Psycopg2 +---------------------- By default, the psycopg2 driver uses the ``psycopg2.extensions.UNICODE`` extension, such that the DBAPI receives and returns all strings as Python @@ -84,27 +101,51 @@ Unicode objects directly - SQLAlchemy passes these values through without change. Psycopg2 here will encode/decode string values based on the current "client encoding" setting; by default this is the value in the ``postgresql.conf`` file, which often defaults to ``SQL_ASCII``. -Typically, this can be changed to ``utf-8``, as a more useful default:: +Typically, this can be changed to ``utf8``, as a more useful default:: + + # postgresql.conf file - #client_encoding = sql_ascii # actually, defaults to database + # client_encoding = sql_ascii # actually, defaults to database # encoding client_encoding = utf8 A second way to affect the client encoding is to set it within Psycopg2 -locally. SQLAlchemy will call psycopg2's ``set_client_encoding()`` -method (see: -http://initd.org/psycopg/docs/connection.html#connection.set_client_encoding) +locally. SQLAlchemy will call psycopg2's +:meth:`psycopg2:connection.set_client_encoding` method on all new connections based on the value passed to :func:`.create_engine` using the ``client_encoding`` parameter:: + # set_client_encoding() setting; + # works for *all* Postgresql versions engine = create_engine("postgresql://user:pass@host/dbname", client_encoding='utf8') This overrides the encoding specified in the Postgresql client configuration. +When using the parameter in this way, the psycopg2 driver emits +``SET client_encoding TO 'utf8'`` on the connection explicitly, and works +in all Postgresql versions. + +Note that the ``client_encoding`` setting as passed to :func:`.create_engine` +is **not the same** as the more recently added ``client_encoding`` parameter +now supported by libpq directly. This is enabled when ``client_encoding`` +is passed directly to ``psycopg2.connect()``, and from SQLAlchemy is passed +using the :paramref:`.create_engine.connect_args` parameter:: + + # libpq direct parameter setting; + # only works for Postgresql **9.1 and above** + engine = create_engine("postgresql://user:pass@host/dbname", + connect_args={'client_encoding': 'utf8'}) + + # using the query string is equivalent + engine = create_engine("postgresql://user:pass@host/dbname?client_encoding=utf8") + +The above parameter was only added to libpq as of version 9.1 of Postgresql, +so using the previous method is better for cross-version support. + +.. _psycopg2_disable_native_unicode: -.. versionadded:: 0.7.3 - The psycopg2-specific ``client_encoding`` parameter to - :func:`.create_engine`. +Disabling Native Unicode +^^^^^^^^^^^^^^^^^^^^^^^^ SQLAlchemy can also be instructed to skip the usage of the psycopg2 ``UNICODE`` extension and to instead utilize its own unicode encode/decode @@ -116,8 +157,7 @@ in and coerce from bytes on the way back, using the value of the :func:`.create_engine` ``encoding`` parameter, which defaults to ``utf-8``. SQLAlchemy's own unicode encode/decode functionality is steadily becoming -obsolete as more DBAPIs support unicode fully along with the approach of -Python 3; in modern usage psycopg2 should be relied upon to handle unicode. +obsolete as most DBAPIs now support unicode fully. Transactions ------------ @@ -512,12 +552,14 @@ class PGDialect_psycopg2(PGDialect): def is_disconnect(self, e, connection, cursor): if isinstance(e, self.dbapi.Error): # check the "closed" flag. this might not be - # present on old psycopg2 versions + # present on old psycopg2 versions. Also, + # this flag doesn't actually help in a lot of disconnect + # situations, so don't rely on it. if getattr(connection, 'closed', False): return True - # legacy checks based on strings. the "closed" check - # above most likely obviates the need for any of these. + # checks based on strings. in the case that .closed + # didn't cut it, fall back onto these. str_e = str(e).partition("\n")[0] for msg in [ # these error messages from libpq: interfaces/libpq/fe-misc.c @@ -534,8 +576,10 @@ class PGDialect_psycopg2(PGDialect): # not sure where this path is originally from, it may # be obsolete. It really says "losed", not "closed". 'losed the connection unexpectedly', - # this can occur in newer SSL - 'connection has been closed unexpectedly' + # these can occur in newer SSL + 'connection has been closed unexpectedly', + 'SSL SYSCALL error: Bad file descriptor', + 'SSL SYSCALL error: EOF detected', ]: idx = str_e.find(msg) if idx >= 0 and '"' not in str_e[:idx]: diff --git a/lib/sqlalchemy/dialects/sqlite/__init__.py b/lib/sqlalchemy/dialects/sqlite/__init__.py index 0eceaa537..a53d53e9d 100644 --- a/lib/sqlalchemy/dialects/sqlite/__init__.py +++ b/lib/sqlalchemy/dialects/sqlite/__init__.py @@ -5,7 +5,7 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -from sqlalchemy.dialects.sqlite import base, pysqlite +from sqlalchemy.dialects.sqlite import base, pysqlite, pysqlcipher # default dialect base.dialect = pysqlite.dialect diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index af793d275..335b35c94 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -713,10 +713,12 @@ class SQLiteExecutionContext(default.DefaultExecutionContext): return self.execution_options.get("sqlite_raw_colnames", False) def _translate_colname(self, colname): - # adjust for dotted column names. SQLite in the case of UNION may - # store col names as "tablename.colname" in cursor.description + # adjust for dotted column names. SQLite + # in the case of UNION may store col names as + # "tablename.colname", or if using an attached database, + # "database.tablename.colname", in cursor.description if not self._preserve_raw_colnames and "." in colname: - return colname.split(".")[1], colname + return colname.split(".")[-1], colname else: return colname, None @@ -829,20 +831,26 @@ class SQLiteDialect(default.DefaultDialect): if schema is not None: qschema = self.identifier_preparer.quote_identifier(schema) master = '%s.sqlite_master' % qschema - s = ("SELECT name FROM %s " - "WHERE type='table' ORDER BY name") % (master,) - rs = connection.execute(s) else: - try: - s = ("SELECT name FROM " - " (SELECT * FROM sqlite_master UNION ALL " - " SELECT * FROM sqlite_temp_master) " - "WHERE type='table' ORDER BY name") - rs = connection.execute(s) - except exc.DBAPIError: - s = ("SELECT name FROM sqlite_master " - "WHERE type='table' ORDER BY name") - rs = connection.execute(s) + master = "sqlite_master" + s = ("SELECT name FROM %s " + "WHERE type='table' ORDER BY name") % (master,) + rs = connection.execute(s) + return [row[0] for row in rs] + + @reflection.cache + def get_temp_table_names(self, connection, **kw): + s = "SELECT name FROM sqlite_temp_master "\ + "WHERE type='table' ORDER BY name " + rs = connection.execute(s) + + return [row[0] for row in rs] + + @reflection.cache + def get_temp_view_names(self, connection, **kw): + s = "SELECT name FROM sqlite_temp_master "\ + "WHERE type='view' ORDER BY name " + rs = connection.execute(s) return [row[0] for row in rs] @@ -869,20 +877,11 @@ class SQLiteDialect(default.DefaultDialect): if schema is not None: qschema = self.identifier_preparer.quote_identifier(schema) master = '%s.sqlite_master' % qschema - s = ("SELECT name FROM %s " - "WHERE type='view' ORDER BY name") % (master,) - rs = connection.execute(s) else: - try: - s = ("SELECT name FROM " - " (SELECT * FROM sqlite_master UNION ALL " - " SELECT * FROM sqlite_temp_master) " - "WHERE type='view' ORDER BY name") - rs = connection.execute(s) - except exc.DBAPIError: - s = ("SELECT name FROM sqlite_master " - "WHERE type='view' ORDER BY name") - rs = connection.execute(s) + master = "sqlite_master" + s = ("SELECT name FROM %s " + "WHERE type='view' ORDER BY name") % (master,) + rs = connection.execute(s) return [row[0] for row in rs] @@ -1097,16 +1096,24 @@ class SQLiteDialect(default.DefaultDialect): @reflection.cache def get_unique_constraints(self, connection, table_name, schema=None, **kw): - UNIQUE_SQL = """ - SELECT sql - FROM - sqlite_master - WHERE - type='table' AND - name=:table_name - """ - c = connection.execute(UNIQUE_SQL, table_name=table_name) - table_data = c.fetchone()[0] + try: + s = ("SELECT sql FROM " + " (SELECT * FROM sqlite_master UNION ALL " + " SELECT * FROM sqlite_temp_master) " + "WHERE name = '%s' " + "AND type = 'table'") % table_name + rs = connection.execute(s) + except exc.DBAPIError: + s = ("SELECT sql FROM sqlite_master WHERE name = '%s' " + "AND type = 'table'") % table_name + rs = connection.execute(s) + row = rs.fetchone() + if row is None: + # sqlite won't return the schema for the sqlite_master or + # sqlite_temp_master tables from this query. These tables + # don't have any unique constraints anyway. + return [] + table_data = row[0] UNIQUE_PATTERN = 'CONSTRAINT (\w+) UNIQUE \(([^\)]+)\)' return [ diff --git a/lib/sqlalchemy/dialects/sqlite/pysqlcipher.py b/lib/sqlalchemy/dialects/sqlite/pysqlcipher.py new file mode 100644 index 000000000..3c55a1de7 --- /dev/null +++ b/lib/sqlalchemy/dialects/sqlite/pysqlcipher.py @@ -0,0 +1,116 @@ +# sqlite/pysqlcipher.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +""" +.. dialect:: sqlite+pysqlcipher + :name: pysqlcipher + :dbapi: pysqlcipher + :connectstring: sqlite+pysqlcipher://:passphrase/file_path[?kdf_iter=<iter>] + :url: https://pypi.python.org/pypi/pysqlcipher + + ``pysqlcipher`` is a fork of the standard ``pysqlite`` driver to make + use of the `SQLCipher <https://www.zetetic.net/sqlcipher>`_ backend. + + .. versionadded:: 0.9.9 + +Driver +------ + +The driver here is the `pysqlcipher <https://pypi.python.org/pypi/pysqlcipher>`_ +driver, which makes use of the SQLCipher engine. This system essentially +introduces new PRAGMA commands to SQLite which allows the setting of a +passphrase and other encryption parameters, allowing the database +file to be encrypted. + +Connect Strings +--------------- + +The format of the connect string is in every way the same as that +of the :mod:`~sqlalchemy.dialects.sqlite.pysqlite` driver, except that the +"password" field is now accepted, which should contain a passphrase:: + + e = create_engine('sqlite+pysqlcipher://:testing@/foo.db') + +For an absolute file path, two leading slashes should be used for the +database name:: + + e = create_engine('sqlite+pysqlcipher://:testing@//path/to/foo.db') + +A selection of additional encryption-related pragmas supported by SQLCipher +as documented at https://www.zetetic.net/sqlcipher/sqlcipher-api/ can be passed +in the query string, and will result in that PRAGMA being called for each +new connection. Currently, ``cipher``, ``kdf_iter`` +``cipher_page_size`` and ``cipher_use_hmac`` are supported:: + + e = create_engine('sqlite+pysqlcipher://:testing@/foo.db?cipher=aes-256-cfb&kdf_iter=64000') + + +Pooling Behavior +---------------- + +The driver makes a change to the default pool behavior of pysqlite +as described in :ref:`pysqlite_threading_pooling`. The pysqlcipher driver +has been observed to be significantly slower on connection than the +pysqlite driver, most likely due to the encryption overhead, so the +dialect here defaults to using the :class:`.SingletonThreadPool` +implementation, +instead of the :class:`.NullPool` pool used by pysqlite. As always, the pool +implementation is entirely configurable using the +:paramref:`.create_engine.poolclass` parameter; the :class:`.StaticPool` may +be more feasible for single-threaded use, or :class:`.NullPool` may be used +to prevent unencrypted connections from being held open for long periods of +time, at the expense of slower startup time for new connections. + + +""" +from __future__ import absolute_import +from .pysqlite import SQLiteDialect_pysqlite +from ...engine import url as _url +from ... import pool + + +class SQLiteDialect_pysqlcipher(SQLiteDialect_pysqlite): + driver = 'pysqlcipher' + + pragmas = ('kdf_iter', 'cipher', 'cipher_page_size', 'cipher_use_hmac') + + @classmethod + def dbapi(cls): + from pysqlcipher import dbapi2 as sqlcipher + return sqlcipher + + @classmethod + def get_pool_class(cls, url): + return pool.SingletonThreadPool + + def connect(self, *cargs, **cparams): + passphrase = cparams.pop('passphrase', '') + + pragmas = dict( + (key, cparams.pop(key, None)) for key in + self.pragmas + ) + + conn = super(SQLiteDialect_pysqlcipher, self).\ + connect(*cargs, **cparams) + conn.execute('pragma key="%s"' % passphrase) + for prag, value in pragmas.items(): + if value is not None: + conn.execute('pragma %s=%s' % (prag, value)) + + return conn + + def create_connect_args(self, url): + super_url = _url.URL( + url.drivername, username=url.username, + host=url.host, database=url.database, query=url.query) + c_args, opts = super(SQLiteDialect_pysqlcipher, self).\ + create_connect_args(super_url) + opts['passphrase'] = url.password + return c_args, opts + +dialect = SQLiteDialect_pysqlcipher diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index d2cc8890f..dd82be1d1 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -45,7 +45,7 @@ class Connection(Connectable): """ def __init__(self, engine, connection=None, close_with_result=False, - _branch=False, _execution_options=None, + _branch_from=None, _execution_options=None, _dispatch=None, _has_events=None): """Construct a new Connection. @@ -57,48 +57,80 @@ class Connection(Connectable): """ self.engine = engine self.dialect = engine.dialect - self.__connection = connection or engine.raw_connection() - self.__transaction = None - self.should_close_with_result = close_with_result - self.__savepoint_seq = 0 - self.__branch = _branch - self.__invalid = False - self.__can_reconnect = True - if _dispatch: + self.__branch_from = _branch_from + self.__branch = _branch_from is not None + + if _branch_from: + self.__connection = connection + self._execution_options = _execution_options + self._echo = _branch_from._echo + self.should_close_with_result = False self.dispatch = _dispatch - elif _has_events is None: - # if _has_events is sent explicitly as False, - # then don't join the dispatch of the engine; we don't - # want to handle any of the engine's events in that case. - self.dispatch = self.dispatch._join(engine.dispatch) - self._has_events = _has_events or ( - _has_events is None and engine._has_events) - - self._echo = self.engine._should_log_info() - if _execution_options: - self._execution_options =\ - engine._execution_options.union(_execution_options) + self._has_events = _branch_from._has_events else: + self.__connection = connection \ + if connection is not None else engine.raw_connection() + self.__transaction = None + self.__savepoint_seq = 0 + self.should_close_with_result = close_with_result + self.__invalid = False + self.__can_reconnect = True + self._echo = self.engine._should_log_info() + + if _has_events is None: + # if _has_events is sent explicitly as False, + # then don't join the dispatch of the engine; we don't + # want to handle any of the engine's events in that case. + self.dispatch = self.dispatch._join(engine.dispatch) + self._has_events = _has_events or ( + _has_events is None and engine._has_events) + + assert not _execution_options self._execution_options = engine._execution_options if self._has_events or self.engine._has_events: - self.dispatch.engine_connect(self, _branch) + self.dispatch.engine_connect(self, self.__branch) def _branch(self): """Return a new Connection which references this Connection's engine and connection; but does not have close_with_result enabled, and also whose close() method does nothing. - This is used to execute "sub" statements within a single execution, - usually an INSERT statement. + The Core uses this very sparingly, only in the case of + custom SQL default functions that are to be INSERTed as the + primary key of a row where we need to get the value back, so we have + to invoke it distinctly - this is a very uncommon case. + + Userland code accesses _branch() when the connect() or + contextual_connect() methods are called. The branched connection + acts as much as possible like the parent, except that it stays + connected when a close() event occurs. + """ + if self.__branch_from: + return self.__branch_from._branch() + else: + return self.engine._connection_cls( + self.engine, + self.__connection, + _branch_from=self, + _execution_options=self._execution_options, + _has_events=self._has_events, + _dispatch=self.dispatch) + + @property + def _root(self): + """return the 'root' connection. - return self.engine._connection_cls( - self.engine, - self.__connection, - _branch=True, - _has_events=self._has_events, - _dispatch=self.dispatch) + Returns 'self' if this connection is not a branch, else + returns the root connection from which we ultimately branched. + + """ + + if self.__branch_from: + return self.__branch_from + else: + return self def _clone(self): """Create a shallow copy of this Connection. @@ -224,7 +256,7 @@ class Connection(Connectable): def invalidated(self): """Return True if this connection was invalidated.""" - return self.__invalid + return self._root.__invalid @property def connection(self): @@ -236,6 +268,9 @@ class Connection(Connectable): return self._revalidate_connection() def _revalidate_connection(self): + if self.__branch_from: + return self.__branch_from._revalidate_connection() + if self.__can_reconnect and self.__invalid: if self.__transaction is not None: raise exc.InvalidRequestError( @@ -343,16 +378,17 @@ class Connection(Connectable): :ref:`pool_connection_invalidation` """ + if self.invalidated: return if self.closed: raise exc.ResourceClosedError("This Connection is closed") - if self._connection_is_valid: - self.__connection.invalidate(exception) - del self.__connection - self.__invalid = True + if self._root._connection_is_valid: + self._root.__connection.invalidate(exception) + del self._root.__connection + self._root.__invalid = True def detach(self): """Detach the underlying DB-API connection from its connection pool. @@ -415,6 +451,8 @@ class Connection(Connectable): :class:`.Engine`. """ + if self.__branch_from: + return self.__branch_from.begin() if self.__transaction is None: self.__transaction = RootTransaction(self) @@ -436,6 +474,9 @@ class Connection(Connectable): See also :meth:`.Connection.begin`, :meth:`.Connection.begin_twophase`. """ + if self.__branch_from: + return self.__branch_from.begin_nested() + if self.__transaction is None: self.__transaction = RootTransaction(self) else: @@ -459,6 +500,9 @@ class Connection(Connectable): """ + if self.__branch_from: + return self.__branch_from.begin_twophase(xid=xid) + if self.__transaction is not None: raise exc.InvalidRequestError( "Cannot start a two phase transaction when a transaction " @@ -479,10 +523,11 @@ class Connection(Connectable): def in_transaction(self): """Return True if a transaction is in progress.""" - - return self.__transaction is not None + return self._root.__transaction is not None def _begin_impl(self, transaction): + assert not self.__branch_from + if self._echo: self.engine.logger.info("BEGIN (implicit)") @@ -497,6 +542,8 @@ class Connection(Connectable): self._handle_dbapi_exception(e, None, None, None, None) def _rollback_impl(self): + assert not self.__branch_from + if self._has_events or self.engine._has_events: self.dispatch.rollback(self) @@ -516,6 +563,8 @@ class Connection(Connectable): self.__transaction = None def _commit_impl(self, autocommit=False): + assert not self.__branch_from + if self._has_events or self.engine._has_events: self.dispatch.commit(self) @@ -532,6 +581,8 @@ class Connection(Connectable): self.__transaction = None def _savepoint_impl(self, name=None): + assert not self.__branch_from + if self._has_events or self.engine._has_events: self.dispatch.savepoint(self, name) @@ -543,6 +594,8 @@ class Connection(Connectable): return name def _rollback_to_savepoint_impl(self, name, context): + assert not self.__branch_from + if self._has_events or self.engine._has_events: self.dispatch.rollback_savepoint(self, name, context) @@ -551,6 +604,8 @@ class Connection(Connectable): self.__transaction = context def _release_savepoint_impl(self, name, context): + assert not self.__branch_from + if self._has_events or self.engine._has_events: self.dispatch.release_savepoint(self, name, context) @@ -559,6 +614,8 @@ class Connection(Connectable): self.__transaction = context def _begin_twophase_impl(self, transaction): + assert not self.__branch_from + if self._echo: self.engine.logger.info("BEGIN TWOPHASE (implicit)") if self._has_events or self.engine._has_events: @@ -571,6 +628,8 @@ class Connection(Connectable): self.connection._reset_agent = transaction def _prepare_twophase_impl(self, xid): + assert not self.__branch_from + if self._has_events or self.engine._has_events: self.dispatch.prepare_twophase(self, xid) @@ -579,6 +638,8 @@ class Connection(Connectable): self.engine.dialect.do_prepare_twophase(self, xid) def _rollback_twophase_impl(self, xid, is_prepared): + assert not self.__branch_from + if self._has_events or self.engine._has_events: self.dispatch.rollback_twophase(self, xid, is_prepared) @@ -595,6 +656,8 @@ class Connection(Connectable): self.__transaction = None def _commit_twophase_impl(self, xid, is_prepared): + assert not self.__branch_from + if self._has_events or self.engine._has_events: self.dispatch.commit_twophase(self, xid, is_prepared) @@ -610,8 +673,8 @@ class Connection(Connectable): self.__transaction = None def _autorollback(self): - if not self.in_transaction(): - self._rollback_impl() + if not self._root.in_transaction(): + self._root._rollback_impl() def close(self): """Close this :class:`.Connection`. @@ -632,13 +695,21 @@ class Connection(Connectable): and will allow no further operations. """ + if self.__branch_from: + try: + del self.__connection + except AttributeError: + pass + finally: + self.__can_reconnect = False + return try: conn = self.__connection except AttributeError: pass else: - if not self.__branch: - conn.close() + + conn.close() if conn._reset_agent is self.__transaction: conn._reset_agent = None @@ -993,8 +1064,8 @@ class Connection(Connectable): result.rowcount result.close(_autoclose_connection=False) - if self.__transaction is None and context.should_autocommit: - self._commit_impl(autocommit=True) + if context.should_autocommit and self._root.__transaction is None: + self._root._commit_impl(autocommit=True) if result.closed and self.should_close_with_result: self.close() @@ -1055,8 +1126,6 @@ class Connection(Connectable): """ try: cursor.close() - except (SystemExit, KeyboardInterrupt): - raise except Exception: # log the error through the connection pool's logger. self.engine.pool.logger.error( diff --git a/lib/sqlalchemy/engine/interfaces.py b/lib/sqlalchemy/engine/interfaces.py index 71df29cac..0ad2efae0 100644 --- a/lib/sqlalchemy/engine/interfaces.py +++ b/lib/sqlalchemy/engine/interfaces.py @@ -308,7 +308,15 @@ class Dialect(object): def get_table_names(self, connection, schema=None, **kw): """Return a list of table names for `schema`.""" - raise NotImplementedError + raise NotImplementedError() + + def get_temp_table_names(self, connection, schema=None, **kw): + """Return a list of temporary table names on the given connection, + if supported by the underlying backend. + + """ + + raise NotImplementedError() def get_view_names(self, connection, schema=None, **kw): """Return a list of all view names available in the database. @@ -319,6 +327,14 @@ class Dialect(object): raise NotImplementedError() + def get_temp_view_names(self, connection, schema=None, **kw): + """Return a list of temporary view names on the given connection, + if supported by the underlying backend. + + """ + + raise NotImplementedError() + def get_view_definition(self, connection, view_name, schema=None, **kw): """Return view definition. diff --git a/lib/sqlalchemy/engine/reflection.py b/lib/sqlalchemy/engine/reflection.py index cf1f2d3dd..2a1def86a 100644 --- a/lib/sqlalchemy/engine/reflection.py +++ b/lib/sqlalchemy/engine/reflection.py @@ -201,6 +201,30 @@ class Inspector(object): tnames = list(topological.sort(tuples, tnames)) return tnames + def get_temp_table_names(self): + """return a list of temporary table names for the current bind. + + This method is unsupported by most dialects; currently + only SQLite implements it. + + .. versionadded:: 1.0.0 + + """ + return self.dialect.get_temp_table_names( + self.bind, info_cache=self.info_cache) + + def get_temp_view_names(self): + """return a list of temporary view names for the current bind. + + This method is unsupported by most dialects; currently + only SQLite implements it. + + .. versionadded:: 1.0.0 + + """ + return self.dialect.get_temp_view_names( + self.bind, info_cache=self.info_cache) + def get_table_options(self, table_name, schema=None, **kw): """Return a dictionary of options specified when the table of the given name was created. @@ -465,55 +489,87 @@ class Inspector(object): for col_d in self.get_columns( table_name, schema, **table.dialect_kwargs): found_table = True - orig_name = col_d['name'] - table.dispatch.column_reflect(self, table, col_d) + self._reflect_column( + table, col_d, include_columns, + exclude_columns, cols_by_orig_name) - name = col_d['name'] - if include_columns and name not in include_columns: - continue - if exclude_columns and name in exclude_columns: - continue + if not found_table: + raise exc.NoSuchTableError(table.name) - coltype = col_d['type'] + self._reflect_pk( + table_name, schema, table, cols_by_orig_name, exclude_columns) - col_kw = dict( - (k, col_d[k]) - for k in ['nullable', 'autoincrement', 'quote', 'info', 'key'] - if k in col_d - ) + self._reflect_fk( + table_name, schema, table, cols_by_orig_name, + exclude_columns, reflection_options) - colargs = [] - if col_d.get('default') is not None: - # the "default" value is assumed to be a literal SQL - # expression, so is wrapped in text() so that no quoting - # occurs on re-issuance. - colargs.append( - sa_schema.DefaultClause( - sql.text(col_d['default']), _reflected=True - ) - ) + self._reflect_indexes( + table_name, schema, table, cols_by_orig_name, + include_columns, exclude_columns, reflection_options) - if 'sequence' in col_d: - # TODO: mssql and sybase are using this. - seq = col_d['sequence'] - sequence = sa_schema.Sequence(seq['name'], 1, 1) - if 'start' in seq: - sequence.start = seq['start'] - if 'increment' in seq: - sequence.increment = seq['increment'] - colargs.append(sequence) + self._reflect_unique_constraints( + table_name, schema, table, cols_by_orig_name, + include_columns, exclude_columns, reflection_options) - cols_by_orig_name[orig_name] = col = \ - sa_schema.Column(name, coltype, *colargs, **col_kw) + def _reflect_column( + self, table, col_d, include_columns, + exclude_columns, cols_by_orig_name): - if col.key in table.primary_key: - col.primary_key = True - table.append_column(col) + orig_name = col_d['name'] - if not found_table: - raise exc.NoSuchTableError(table.name) + table.dispatch.column_reflect(self, table, col_d) + + # fetch name again as column_reflect is allowed to + # change it + name = col_d['name'] + if (include_columns and name not in include_columns) \ + or (exclude_columns and name in exclude_columns): + return + + coltype = col_d['type'] + + col_kw = dict( + (k, col_d[k]) + for k in ['nullable', 'autoincrement', 'quote', 'info', 'key'] + if k in col_d + ) + + colargs = [] + if col_d.get('default') is not None: + # the "default" value is assumed to be a literal SQL + # expression, so is wrapped in text() so that no quoting + # occurs on re-issuance. + colargs.append( + sa_schema.DefaultClause( + sql.text(col_d['default']), _reflected=True + ) + ) + if 'sequence' in col_d: + self._reflect_col_sequence(col_d, colargs) + + cols_by_orig_name[orig_name] = col = \ + sa_schema.Column(name, coltype, *colargs, **col_kw) + + if col.key in table.primary_key: + col.primary_key = True + table.append_column(col) + + def _reflect_col_sequence(self, col_d, colargs): + if 'sequence' in col_d: + # TODO: mssql and sybase are using this. + seq = col_d['sequence'] + sequence = sa_schema.Sequence(seq['name'], 1, 1) + if 'start' in seq: + sequence.start = seq['start'] + if 'increment' in seq: + sequence.increment = seq['increment'] + colargs.append(sequence) + + def _reflect_pk( + self, table_name, schema, table, + cols_by_orig_name, exclude_columns): pk_cons = self.get_pk_constraint( table_name, schema, **table.dialect_kwargs) if pk_cons: @@ -530,6 +586,9 @@ class Inspector(object): # its column collection table.primary_key._reload(pk_cols) + def _reflect_fk( + self, table_name, schema, table, cols_by_orig_name, + exclude_columns, reflection_options): fkeys = self.get_foreign_keys( table_name, schema, **table.dialect_kwargs) for fkey_d in fkeys: @@ -572,6 +631,10 @@ class Inspector(object): sa_schema.ForeignKeyConstraint(constrained_columns, refspec, conname, link_to_name=True, **options)) + + def _reflect_indexes( + self, table_name, schema, table, cols_by_orig_name, + include_columns, exclude_columns, reflection_options): # Indexes indexes = self.get_indexes(table_name, schema) for index_d in indexes: @@ -579,12 +642,15 @@ class Inspector(object): columns = index_d['column_names'] unique = index_d['unique'] flavor = index_d.get('type', 'index') + duplicates = index_d.get('duplicates_constraint') if include_columns and \ not set(columns).issubset(include_columns): util.warn( "Omitting %s key for (%s), key covers omitted columns." % (flavor, ', '.join(columns))) continue + if duplicates: + continue # look for columns by orig name in cols_by_orig_name, # but support columns that are in-Python only as fallback idx_cols = [] @@ -602,3 +668,43 @@ class Inspector(object): idx_cols.append(idx_col) sa_schema.Index(name, *idx_cols, **dict(unique=unique)) + + def _reflect_unique_constraints( + self, table_name, schema, table, cols_by_orig_name, + include_columns, exclude_columns, reflection_options): + + # Unique Constraints + try: + constraints = self.get_unique_constraints(table_name, schema) + except NotImplementedError: + # optional dialect feature + return + + for const_d in constraints: + conname = const_d['name'] + columns = const_d['column_names'] + duplicates = const_d.get('duplicates_index') + if include_columns and \ + not set(columns).issubset(include_columns): + util.warn( + "Omitting unique constraint key for (%s), " + "key covers omitted columns." % + ', '.join(columns)) + continue + if duplicates: + continue + # look for columns by orig name in cols_by_orig_name, + # but support columns that are in-Python only as fallback + constrained_cols = [] + for c in columns: + try: + constrained_col = cols_by_orig_name[c] \ + if c in cols_by_orig_name else table.c[c] + except KeyError: + util.warn( + "unique constraint key '%s' was not located in " + "columns for table '%s'" % (c, table_name)) + else: + constrained_cols.append(constrained_col) + table.append_constraint( + sa_schema.UniqueConstraint(*constrained_cols, name=conname)) diff --git a/lib/sqlalchemy/engine/strategies.py b/lib/sqlalchemy/engine/strategies.py index 38206be89..398ef8df6 100644 --- a/lib/sqlalchemy/engine/strategies.py +++ b/lib/sqlalchemy/engine/strategies.py @@ -162,6 +162,7 @@ class DefaultEngineStrategy(EngineStrategy): def first_connect(dbapi_connection, connection_record): c = base.Connection(engine, connection=dbapi_connection, _has_events=False) + c._execution_options = util.immutabledict() dialect.initialize(c) event.listen(pool, 'first_connect', first_connect, once=True) diff --git a/lib/sqlalchemy/event/attr.py b/lib/sqlalchemy/event/attr.py index dba1063cf..be2a82208 100644 --- a/lib/sqlalchemy/event/attr.py +++ b/lib/sqlalchemy/event/attr.py @@ -319,14 +319,12 @@ class _ListenerCollection(RefCollection, _CompoundListener): registry._stored_in_collection_multi(self, other, to_associate) def insert(self, event_key, propagate): - if event_key._listen_fn not in self.listeners: - event_key.prepend_to_list(self, self.listeners) + if event_key.prepend_to_list(self, self.listeners): if propagate: self.propagate.add(event_key._listen_fn) def append(self, event_key, propagate): - if event_key._listen_fn not in self.listeners: - event_key.append_to_list(self, self.listeners) + if event_key.append_to_list(self, self.listeners): if propagate: self.propagate.add(event_key._listen_fn) diff --git a/lib/sqlalchemy/event/registry.py b/lib/sqlalchemy/event/registry.py index ba2f671a3..5b422c401 100644 --- a/lib/sqlalchemy/event/registry.py +++ b/lib/sqlalchemy/event/registry.py @@ -71,13 +71,15 @@ def _stored_in_collection(event_key, owner): listen_ref = weakref.ref(event_key._listen_fn) if owner_ref in dispatch_reg: - assert dispatch_reg[owner_ref] == listen_ref - else: - dispatch_reg[owner_ref] = listen_ref + return False + + dispatch_reg[owner_ref] = listen_ref listener_to_key = _collection_to_key[owner_ref] listener_to_key[listen_ref] = key + return True + def _removed_from_collection(event_key, owner): key = event_key._key @@ -180,6 +182,17 @@ class _EventKey(object): def listen(self, *args, **kw): once = kw.pop("once", False) + named = kw.pop("named", False) + + target, identifier, fn = \ + self.dispatch_target, self.identifier, self._listen_fn + + dispatch_descriptor = getattr(target.dispatch, identifier) + + adjusted_fn = dispatch_descriptor._adjust_fn_spec(fn, named) + + self = self.with_wrapper(adjusted_fn) + if once: self.with_wrapper( util.only_once(self._listen_fn)).listen(*args, **kw) @@ -215,9 +228,6 @@ class _EventKey(object): dispatch_descriptor = getattr(target.dispatch, identifier) - fn = dispatch_descriptor._adjust_fn_spec(fn, named) - self = self.with_wrapper(fn) - if insert: dispatch_descriptor.\ for_modify(target.dispatch).insert(self, propagate) @@ -229,18 +239,20 @@ class _EventKey(object): def _listen_fn(self): return self.fn_wrap or self.fn - def append_value_to_list(self, owner, list_, value): - _stored_in_collection(self, owner) - list_.append(value) - def append_to_list(self, owner, list_): - _stored_in_collection(self, owner) - list_.append(self._listen_fn) + if _stored_in_collection(self, owner): + list_.append(self._listen_fn) + return True + else: + return False def remove_from_list(self, owner, list_): _removed_from_collection(self, owner) list_.remove(self._listen_fn) def prepend_to_list(self, owner, list_): - _stored_in_collection(self, owner) - list_.appendleft(self._listen_fn) + if _stored_in_collection(self, owner): + list_.appendleft(self._listen_fn) + return True + else: + return False diff --git a/lib/sqlalchemy/events.py b/lib/sqlalchemy/events.py index 1ecec51b6..b4f057b0a 100644 --- a/lib/sqlalchemy/events.py +++ b/lib/sqlalchemy/events.py @@ -338,7 +338,7 @@ class PoolEvents(event.Events): """ - def reset(self, dbapi_connnection, connection_record): + def reset(self, dbapi_connection, connection_record): """Called before the "reset" action occurs for a pooled connection. This event represents @@ -470,7 +470,8 @@ class ConnectionEvents(event.Events): @classmethod def _listen(cls, event_key, retval=False): target, identifier, fn = \ - event_key.dispatch_target, event_key.identifier, event_key.fn + event_key.dispatch_target, event_key.identifier, \ + event_key._listen_fn target._has_events = True diff --git a/lib/sqlalchemy/exc.py b/lib/sqlalchemy/exc.py index a82bae33f..3271d09d4 100644 --- a/lib/sqlalchemy/exc.py +++ b/lib/sqlalchemy/exc.py @@ -238,14 +238,16 @@ class StatementError(SQLAlchemyError): def __str__(self): from sqlalchemy.sql import util - params_repr = util._repr_params(self.params, 10) + details = [SQLAlchemyError.__str__(self)] + if self.statement: + details.append("[SQL: %r]" % self.statement) + if self.params: + params_repr = util._repr_params(self.params, 10) + details.append("[parameters: %r]" % params_repr) return ' '.join([ "(%s)" % det for det in self.detail - ] + [ - SQLAlchemyError.__str__(self), - repr(self.statement), repr(params_repr) - ]) + ] + details) def __unicode__(self): return self.__str__() @@ -280,17 +282,19 @@ class DBAPIError(StatementError): connection_invalidated=False): # Don't ever wrap these, just return them directly as if # DBAPIError didn't exist. - if isinstance(orig, (KeyboardInterrupt, SystemExit, DontWrapMixin)): + if (isinstance(orig, BaseException) and + not isinstance(orig, Exception)) or \ + isinstance(orig, DontWrapMixin): return orig if orig is not None: # not a DBAPI error, statement is present. # raise a StatementError if not isinstance(orig, dbapi_base_err) and statement: - msg = traceback.format_exception_only( - orig.__class__, orig)[-1].strip() return StatementError( - "%s (original cause: %s)" % (str(orig), msg), + "(%s.%s) %s" % + (orig.__class__.__module__, orig.__class__.__name__, + orig), statement, params, orig ) @@ -310,13 +314,12 @@ class DBAPIError(StatementError): def __init__(self, statement, params, orig, connection_invalidated=False): try: text = str(orig) - except (KeyboardInterrupt, SystemExit): - raise except Exception as e: text = 'Error in str() of DB-API-generated exception: ' + str(e) StatementError.__init__( self, - '(%s) %s' % (orig.__class__.__name__, text), + '(%s.%s) %s' % ( + orig.__class__.__module__, orig.__class__.__name__, text, ), statement, params, orig diff --git a/lib/sqlalchemy/ext/automap.py b/lib/sqlalchemy/ext/automap.py index 121285ab3..c11795d37 100644 --- a/lib/sqlalchemy/ext/automap.py +++ b/lib/sqlalchemy/ext/automap.py @@ -243,7 +243,26 @@ follows: one-to-many backref will be created on the referred class referring to this class. -4. The names of the relationships are determined using the +4. If any of the columns that are part of the :class:`.ForeignKeyConstraint` + are not nullable (e.g. ``nullable=False``), a + :paramref:`~.relationship.cascade` keyword argument + of ``all, delete-orphan`` will be added to the keyword arguments to + be passed to the relationship or backref. If the + :class:`.ForeignKeyConstraint` reports that + :paramref:`.ForeignKeyConstraint.ondelete` + is set to ``CASCADE`` for a not null or ``SET NULL`` for a nullable + set of columns, the option :paramref:`~.relationship.passive_deletes` + flag is set to ``True`` in the set of relationship keyword arguments. + Note that not all backends support reflection of ON DELETE. + + .. versionadded:: 1.0.0 - automap will detect non-nullable foreign key + constraints when producing a one-to-many relationship and establish + a default cascade of ``all, delete-orphan`` if so; additionally, + if the constraint specifies :paramref:`.ForeignKeyConstraint.ondelete` + of ``CASCADE`` for non-nullable or ``SET NULL`` for nullable columns, + the ``passive_deletes=True`` option is also added. + +5. The names of the relationships are determined using the :paramref:`.AutomapBase.prepare.name_for_scalar_relationship` and :paramref:`.AutomapBase.prepare.name_for_collection_relationship` callable functions. It is important to note that the default relationship @@ -252,18 +271,18 @@ follows: alternate class naming scheme, that's the name from which the relationship name will be derived. -5. The classes are inspected for an existing mapped property matching these +6. The classes are inspected for an existing mapped property matching these names. If one is detected on one side, but none on the other side, :class:`.AutomapBase` attempts to create a relationship on the missing side, then uses the :paramref:`.relationship.back_populates` parameter in order to point the new relationship to the other side. -6. In the usual case where no relationship is on either side, +7. In the usual case where no relationship is on either side, :meth:`.AutomapBase.prepare` produces a :func:`.relationship` on the "many-to-one" side and matches it to the other using the :paramref:`.relationship.backref` parameter. -7. Production of the :func:`.relationship` and optionally the :func:`.backref` +8. Production of the :func:`.relationship` and optionally the :func:`.backref` is handed off to the :paramref:`.AutomapBase.prepare.generate_relationship` function, which can be supplied by the end-user in order to augment the arguments passed to :func:`.relationship` or :func:`.backref` or to @@ -877,6 +896,19 @@ def _relationships_for_fks(automap_base, map_config, table_to_map_config, constraint ) + o2m_kws = {} + nullable = False not in set([fk.parent.nullable for fk in fks]) + if not nullable: + o2m_kws['cascade'] = "all, delete-orphan" + + if constraint.ondelete and \ + constraint.ondelete.lower() == "cascade": + o2m_kws['passive_deletes'] = True + else: + if constraint.ondelete and \ + constraint.ondelete.lower() == "set null": + o2m_kws['passive_deletes'] = True + create_backref = backref_name not in referred_cfg.properties if relationship_name not in map_config.properties: @@ -885,7 +917,8 @@ def _relationships_for_fks(automap_base, map_config, table_to_map_config, automap_base, interfaces.ONETOMANY, backref, backref_name, referred_cls, local_cls, - collection_class=collection_class) + collection_class=collection_class, + **o2m_kws) else: backref_obj = None rel = generate_relationship(automap_base, @@ -916,7 +949,8 @@ def _relationships_for_fks(automap_base, map_config, table_to_map_config, fk.parent for fk in constraint.elements], back_populates=relationship_name, - collection_class=collection_class) + collection_class=collection_class, + **o2m_kws) if rel is not None: referred_cfg.properties[backref_name] = rel map_config.properties[ diff --git a/lib/sqlalchemy/ext/declarative/__init__.py b/lib/sqlalchemy/ext/declarative/__init__.py index 3cbc85c0c..2b611252a 100644 --- a/lib/sqlalchemy/ext/declarative/__init__.py +++ b/lib/sqlalchemy/ext/declarative/__init__.py @@ -873,8 +873,7 @@ the method without the need to copy it. Columns generated by :class:`~.declared_attr` can also be referenced by ``__mapper_args__`` to a limited degree, currently -by ``polymorphic_on`` and ``version_id_col``, by specifying the -classdecorator itself into the dictionary - the declarative extension +by ``polymorphic_on`` and ``version_id_col``; the declarative extension will resolve them at class construction time:: class MyMixin: @@ -889,7 +888,6 @@ will resolve them at class construction time:: id = Column(Integer, primary_key=True) - Mixing in Relationships ~~~~~~~~~~~~~~~~~~~~~~~ @@ -922,6 +920,7 @@ reference a common target class via many-to-one:: __tablename__ = 'target' id = Column(Integer, primary_key=True) + Using Advanced Relationship Arguments (e.g. ``primaryjoin``, etc.) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -1004,6 +1003,24 @@ requirement so that no reliance on copying is needed:: class Something(SomethingMixin, Base): __tablename__ = "something" +The :func:`.column_property` or other construct may refer +to other columns from the mixin. These are copied ahead of time before +the :class:`.declared_attr` is invoked:: + + class SomethingMixin(object): + x = Column(Integer) + + y = Column(Integer) + + @declared_attr + def x_plus_y(cls): + return column_property(cls.x + cls.y) + + +.. versionchanged:: 1.0.0 mixin columns are copied to the final mapped class + so that :class:`.declared_attr` methods can access the actual column + that will be mapped. + Mixing in Association Proxy and Other Attributes ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -1087,19 +1104,20 @@ and ``TypeB`` classes. Controlling table inheritance with mixins ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -The ``__tablename__`` attribute in conjunction with the hierarchy of -classes involved in a declarative mixin scenario controls what type of -table inheritance, if any, -is configured by the declarative extension. +The ``__tablename__`` attribute may be used to provide a function that +will determine the name of the table used for each class in an inheritance +hierarchy, as well as whether a class has its own distinct table. -If the ``__tablename__`` is computed by a mixin, you may need to -control which classes get the computed attribute in order to get the -type of table inheritance you require. +This is achieved using the :class:`.declared_attr` indicator in conjunction +with a method named ``__tablename__()``. Declarative will always +invoke :class:`.declared_attr` for the special names +``__tablename__``, ``__mapper_args__`` and ``__table_args__`` +function **for each mapped class in the hierarchy**. The function therefore +needs to expect to receive each class individually and to provide the +correct answer for each. -For example, if you had a mixin that computes ``__tablename__`` but -where you wanted to use that mixin in a single table inheritance -hierarchy, you can explicitly specify ``__tablename__`` as ``None`` to -indicate that the class should not have a table mapped:: +For example, to create a mixin that gives every class a simple table +name based on class name:: from sqlalchemy.ext.declarative import declared_attr @@ -1118,15 +1136,10 @@ indicate that the class should not have a table mapped:: __mapper_args__ = {'polymorphic_identity': 'engineer'} primary_language = Column(String(50)) -Alternatively, you can make the mixin intelligent enough to only -return a ``__tablename__`` in the event that no table is already -mapped in the inheritance hierarchy. To help with this, a -:func:`~sqlalchemy.ext.declarative.has_inherited_table` helper -function is provided that returns ``True`` if a parent class already -has a mapped table. - -As an example, here's a mixin that will only allow single table -inheritance:: +Alternatively, we can modify our ``__tablename__`` function to return +``None`` for subclasses, using :func:`.has_inherited_table`. This has +the effect of those subclasses being mapped with single table inheritance +agaisnt the parent:: from sqlalchemy.ext.declarative import declared_attr from sqlalchemy.ext.declarative import has_inherited_table @@ -1147,6 +1160,64 @@ inheritance:: primary_language = Column(String(50)) __mapper_args__ = {'polymorphic_identity': 'engineer'} +.. _mixin_inheritance_columns: + +Mixing in Columns in Inheritance Scenarios +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +In constrast to how ``__tablename__`` and other special names are handled when +used with :class:`.declared_attr`, when we mix in columns and properties (e.g. +relationships, column properties, etc.), the function is +invoked for the **base class only** in the hierarchy. Below, only the +``Person`` class will receive a column +called ``id``; the mapping will fail on ``Engineer``, which is not given +a primary key:: + + class HasId(object): + @declared_attr + def id(cls): + return Column('id', Integer, primary_key=True) + + class Person(HasId, Base): + __tablename__ = 'person' + discriminator = Column('type', String(50)) + __mapper_args__ = {'polymorphic_on': discriminator} + + class Engineer(Person): + __tablename__ = 'engineer' + primary_language = Column(String(50)) + __mapper_args__ = {'polymorphic_identity': 'engineer'} + +It is usually the case in joined-table inheritance that we want distinctly +named columns on each subclass. However in this case, we may want to have +an ``id`` column on every table, and have them refer to each other via +foreign key. We can achieve this as a mixin by using the +:attr:`.declared_attr.cascading` modifier, which indicates that the +function should be invoked **for each class in the hierarchy**, just like +it does for ``__tablename__``:: + + class HasId(object): + @declared_attr.cascading + def id(cls): + if has_inherited_table(cls): + return Column('id', + Integer, + ForeignKey('person.id'), primary_key=True) + else: + return Column('id', Integer, primary_key=True) + + class Person(HasId, Base): + __tablename__ = 'person' + discriminator = Column('type', String(50)) + __mapper_args__ = {'polymorphic_on': discriminator} + + class Engineer(Person): + __tablename__ = 'engineer' + primary_language = Column(String(50)) + __mapper_args__ = {'polymorphic_identity': 'engineer'} + + +.. versionadded:: 1.0.0 added :attr:`.declared_attr.cascading`. Combining Table/Mapper Arguments from Multiple Mixins ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/lib/sqlalchemy/ext/declarative/api.py b/lib/sqlalchemy/ext/declarative/api.py index daf8bffb5..66fe05fd0 100644 --- a/lib/sqlalchemy/ext/declarative/api.py +++ b/lib/sqlalchemy/ext/declarative/api.py @@ -8,12 +8,13 @@ from ...schema import Table, MetaData -from ...orm import synonym as _orm_synonym, mapper,\ +from ...orm import synonym as _orm_synonym, \ comparable_property,\ - interfaces, properties + interfaces, properties, attributes from ...orm.util import polymorphic_union from ...orm.base import _mapper_or_none -from ...util import OrderedDict +from ...util import OrderedDict, hybridmethod, hybridproperty +from ... import util from ... import exc import weakref @@ -21,7 +22,6 @@ from .base import _as_declarative, \ _declarative_constructor,\ _DeferredMapperConfig, _add_attribute from .clsregistry import _class_resolver -from . import clsregistry def instrument_declarative(cls, registry, metadata): @@ -157,12 +157,98 @@ class declared_attr(interfaces._MappedAttribute, property): """ - def __init__(self, fget, *arg, **kw): - super(declared_attr, self).__init__(fget, *arg, **kw) + def __init__(self, fget, cascading=False): + super(declared_attr, self).__init__(fget) self.__doc__ = fget.__doc__ + self._cascading = cascading def __get__(desc, self, cls): - return desc.fget(cls) + # use the ClassManager for memoization of values. This is better than + # adding yet another attribute onto the class, or using weakrefs + # here which are slow and take up memory. It also allows us to + # warn for non-mapped use of declared_attr. + + manager = attributes.manager_of_class(cls) + if manager is None: + util.warn( + "Unmanaged access of declarative attribute %s from " + "non-mapped class %s" % + (desc.fget.__name__, cls.__name__)) + return desc.fget(cls) + try: + reg = manager.info['declared_attr_reg'] + except KeyError: + raise exc.InvalidRequestError( + "@declared_attr called outside of the " + "declarative mapping process; is declarative_base() being " + "used correctly?") + + if desc in reg: + return reg[desc] + else: + reg[desc] = obj = desc.fget(cls) + return obj + + @hybridmethod + def _stateful(cls, **kw): + return _stateful_declared_attr(**kw) + + @hybridproperty + def cascading(cls): + """Mark a :class:`.declared_attr` as cascading. + + This is a special-use modifier which indicates that a column + or MapperProperty-based declared attribute should be configured + distinctly per mapped subclass, within a mapped-inheritance scenario. + + Below, both MyClass as well as MySubClass will have a distinct + ``id`` Column object established:: + + class HasSomeAttribute(object): + @declared_attr.cascading + def some_id(cls): + if has_inherited_table(cls): + return Column( + ForeignKey('myclass.id'), primary_key=True) + else: + return Column(Integer, primary_key=True) + + return Column('id', Integer, primary_key=True) + + class MyClass(HasSomeAttribute, Base): + "" + # ... + + class MySubClass(MyClass): + "" + # ... + + The behavior of the above configuration is that ``MySubClass`` + will refer to both its own ``id`` column as well as that of + ``MyClass`` underneath the attribute named ``some_id``. + + .. seealso:: + + :ref:`declarative_inheritance` + + :ref:`mixin_inheritance_columns` + + + """ + return cls._stateful(cascading=True) + + +class _stateful_declared_attr(declared_attr): + def __init__(self, **kw): + self.kw = kw + + def _stateful(self, **kw): + new_kw = self.kw.copy() + new_kw.update(kw) + return _stateful_declared_attr(**new_kw) + + def __call__(self, fn): + return declared_attr(fn, **self.kw) def declarative_base(bind=None, metadata=None, mapper=None, cls=object, @@ -349,9 +435,11 @@ class AbstractConcreteBase(ConcreteBase): ``__declare_last__()`` function, which is essentially a hook for the :meth:`.after_configured` event. - :class:`.AbstractConcreteBase` does not produce a mapped - table for the class itself. Compare to :class:`.ConcreteBase`, - which does. + :class:`.AbstractConcreteBase` does produce a mapped class + for the base class, however it is not persisted to any table; it + is instead mapped directly to the "polymorphic" selectable directly + and is only used for selecting. Compare to :class:`.ConcreteBase`, + which does create a persisted table for the base class. Example:: @@ -365,20 +453,72 @@ class AbstractConcreteBase(ConcreteBase): employee_id = Column(Integer, primary_key=True) name = Column(String(50)) manager_data = Column(String(40)) + __mapper_args__ = { - 'polymorphic_identity':'manager', - 'concrete':True} + 'polymorphic_identity':'manager', + 'concrete':True} + + The abstract base class is handled by declarative in a special way; + at class configuration time, it behaves like a declarative mixin + or an ``__abstract__`` base class. Once classes are configured + and mappings are produced, it then gets mapped itself, but + after all of its decscendants. This is a very unique system of mapping + not found in any other SQLAlchemy system. + + Using this approach, we can specify columns and properties + that will take place on mapped subclasses, in the way that + we normally do as in :ref:`declarative_mixins`:: + + class Company(Base): + __tablename__ = 'company' + id = Column(Integer, primary_key=True) + + class Employee(AbstractConcreteBase, Base): + employee_id = Column(Integer, primary_key=True) + + @declared_attr + def company_id(cls): + return Column(ForeignKey('company.id')) + + @declared_attr + def company(cls): + return relationship("Company") + + class Manager(Employee): + __tablename__ = 'manager' + + name = Column(String(50)) + manager_data = Column(String(40)) + + __mapper_args__ = { + 'polymorphic_identity':'manager', + 'concrete':True} + + When we make use of our mappings however, both ``Manager`` and + ``Employee`` will have an independently usable ``.company`` attribute:: + + session.query(Employee).filter(Employee.company.has(id=5)) + + .. versionchanged:: 1.0.0 - The mechanics of :class:`.AbstractConcreteBase` + have been reworked to support relationships established directly + on the abstract base, without any special configurational steps. + """ - __abstract__ = True + __no_table__ = True @classmethod def __declare_first__(cls): - if hasattr(cls, '__mapper__'): + cls._sa_decl_prepare_nocascade() + + @classmethod + def _sa_decl_prepare_nocascade(cls): + if getattr(cls, '__mapper__', None): return - clsregistry.add_class(cls.__name__, cls) + to_map = _DeferredMapperConfig.config_for_cls(cls) + # can't rely on 'self_and_descendants' here # since technically an immediate subclass # might not be mapped, but a subclass @@ -392,11 +532,22 @@ class AbstractConcreteBase(ConcreteBase): if mn is not None: mappers.append(mn) pjoin = cls._create_polymorphic_union(mappers) - cls.__mapper__ = m = mapper(cls, pjoin, polymorphic_on=pjoin.c.type) + + to_map.local_table = pjoin + + m_args = to_map.mapper_args_fn or dict + + def mapper_args(): + args = m_args() + args['polymorphic_on'] = pjoin.c.type + return args + to_map.mapper_args_fn = mapper_args + + m = to_map.map() for scls in cls.__subclasses__(): sm = _mapper_or_none(scls) - if sm.concrete and cls in scls.__bases__: + if sm and sm.concrete and cls in scls.__bases__: sm._set_concrete_base(m) diff --git a/lib/sqlalchemy/ext/declarative/base.py b/lib/sqlalchemy/ext/declarative/base.py index 94baeeb51..291608b6c 100644 --- a/lib/sqlalchemy/ext/declarative/base.py +++ b/lib/sqlalchemy/ext/declarative/base.py @@ -19,6 +19,9 @@ from ... import event from . import clsregistry import collections import weakref +from sqlalchemy.orm import instrumentation + +declared_attr = declarative_props = None def _declared_mapping_info(cls): @@ -32,322 +35,407 @@ def _declared_mapping_info(cls): return None +def _get_immediate_cls_attr(cls, attrname): + """return an attribute of the class that is either present directly + on the class, e.g. not on a superclass, or is from a superclass but + this superclass is a mixin, that is, not a descendant of + the declarative base. + + This is used to detect attributes that indicate something about + a mapped class independently from any mapped classes that it may + inherit from. + + """ + for base in cls.__mro__: + _is_declarative_inherits = hasattr(base, '_decl_class_registry') + if attrname in base.__dict__: + value = getattr(base, attrname) + if (base is cls or + (base in cls.__bases__ and not _is_declarative_inherits)): + return value + else: + return None + + def _as_declarative(cls, classname, dict_): - from .api import declared_attr + global declared_attr, declarative_props + if declared_attr is None: + from .api import declared_attr + declarative_props = (declared_attr, util.classproperty) - # dict_ will be a dictproxy, which we can't write to, and we need to! - dict_ = dict(dict_) + if _get_immediate_cls_attr(cls, '__abstract__'): + return - column_copies = {} - potential_columns = {} + _MapperConfig.setup_mapping(cls, classname, dict_) - mapper_args_fn = None - table_args = inherited_table_args = None - tablename = None - declarative_props = (declared_attr, util.classproperty) +class _MapperConfig(object): - for base in cls.__mro__: - _is_declarative_inherits = hasattr(base, '_decl_class_registry') + @classmethod + def setup_mapping(cls, cls_, classname, dict_): + defer_map = _get_immediate_cls_attr( + cls_, '_sa_decl_prepare_nocascade') or \ + hasattr(cls_, '_sa_decl_prepare') - if '__declare_last__' in base.__dict__: - @event.listens_for(mapper, "after_configured") - def go(): - cls.__declare_last__() - if '__declare_first__' in base.__dict__: - @event.listens_for(mapper, "before_configured") - def go(): - cls.__declare_first__() - if '__abstract__' in base.__dict__ and base.__abstract__: - if (base is cls or - (base in cls.__bases__ and not _is_declarative_inherits)): - return + if defer_map: + cfg_cls = _DeferredMapperConfig + else: + cfg_cls = _MapperConfig + cfg_cls(cls_, classname, dict_) - class_mapped = _declared_mapping_info(base) is not None + def __init__(self, cls_, classname, dict_): - for name, obj in vars(base).items(): - if name == '__mapper_args__': - if not mapper_args_fn and ( - not class_mapped or - isinstance(obj, declarative_props) - ): - # don't even invoke __mapper_args__ until - # after we've determined everything about the - # mapped table. - # make a copy of it so a class-level dictionary - # is not overwritten when we update column-based - # arguments. - mapper_args_fn = lambda: dict(cls.__mapper_args__) - elif name == '__tablename__': - if not tablename and ( - not class_mapped or - isinstance(obj, declarative_props) - ): - tablename = cls.__tablename__ - elif name == '__table_args__': - if not table_args and ( - not class_mapped or - isinstance(obj, declarative_props) - ): - table_args = cls.__table_args__ - if not isinstance(table_args, (tuple, dict, type(None))): - raise exc.ArgumentError( - "__table_args__ value must be a tuple, " - "dict, or None") - if base is not cls: - inherited_table_args = True - elif class_mapped: - if isinstance(obj, declarative_props): - util.warn("Regular (i.e. not __special__) " - "attribute '%s.%s' uses @declared_attr, " - "but owning class %s is mapped - " - "not applying to subclass %s." - % (base.__name__, name, base, cls)) - continue - elif base is not cls: - # we're a mixin. - if isinstance(obj, Column): - if getattr(cls, name) is not obj: - # if column has been overridden - # (like by the InstrumentedAttribute of the - # superclass), skip + self.cls = cls_ + + # dict_ will be a dictproxy, which we can't write to, and we need to! + self.dict_ = dict(dict_) + self.classname = classname + self.mapped_table = None + self.properties = util.OrderedDict() + self.declared_columns = set() + self.column_copies = {} + self._setup_declared_events() + + # register up front, so that @declared_attr can memoize + # function evaluations in .info + manager = instrumentation.register_class(self.cls) + manager.info['declared_attr_reg'] = {} + + self._scan_attributes() + + clsregistry.add_class(self.classname, self.cls) + + self._extract_mappable_attributes() + + self._extract_declared_columns() + + self._setup_table() + + self._setup_inheritance() + + self._early_mapping() + + def _early_mapping(self): + self.map() + + def _setup_declared_events(self): + if _get_immediate_cls_attr(self.cls, '__declare_last__'): + @event.listens_for(mapper, "after_configured") + def after_configured(): + self.cls.__declare_last__() + + if _get_immediate_cls_attr(self.cls, '__declare_first__'): + @event.listens_for(mapper, "before_configured") + def before_configured(): + self.cls.__declare_first__() + + def _scan_attributes(self): + cls = self.cls + dict_ = self.dict_ + column_copies = self.column_copies + mapper_args_fn = None + table_args = inherited_table_args = None + tablename = None + + for base in cls.__mro__: + class_mapped = base is not cls and \ + _declared_mapping_info(base) is not None and \ + not _get_immediate_cls_attr(base, '_sa_decl_prepare_nocascade') + + if not class_mapped and base is not cls: + self._produce_column_copies(base) + + for name, obj in vars(base).items(): + if name == '__mapper_args__': + if not mapper_args_fn and ( + not class_mapped or + isinstance(obj, declarative_props) + ): + # don't even invoke __mapper_args__ until + # after we've determined everything about the + # mapped table. + # make a copy of it so a class-level dictionary + # is not overwritten when we update column-based + # arguments. + mapper_args_fn = lambda: dict(cls.__mapper_args__) + elif name == '__tablename__': + if not tablename and ( + not class_mapped or + isinstance(obj, declarative_props) + ): + tablename = cls.__tablename__ + elif name == '__table_args__': + if not table_args and ( + not class_mapped or + isinstance(obj, declarative_props) + ): + table_args = cls.__table_args__ + if not isinstance( + table_args, (tuple, dict, type(None))): + raise exc.ArgumentError( + "__table_args__ value must be a tuple, " + "dict, or None") + if base is not cls: + inherited_table_args = True + elif class_mapped: + if isinstance(obj, declarative_props): + util.warn("Regular (i.e. not __special__) " + "attribute '%s.%s' uses @declared_attr, " + "but owning class %s is mapped - " + "not applying to subclass %s." + % (base.__name__, name, base, cls)) + continue + elif base is not cls: + # we're a mixin, abstract base, or something that is + # acting like that for now. + if isinstance(obj, Column): + # already copied columns to the mapped class. continue - if obj.foreign_keys: + elif isinstance(obj, MapperProperty): raise exc.InvalidRequestError( - "Columns with foreign keys to other columns " - "must be declared as @declared_attr callables " - "on declarative mixin classes. ") - if name not in dict_ and not ( - '__table__' in dict_ and - (obj.name or name) in dict_['__table__'].c - ) and name not in potential_columns: - potential_columns[name] = \ - column_copies[obj] = \ - obj.copy() - column_copies[obj]._creation_order = \ - obj._creation_order - elif isinstance(obj, MapperProperty): + "Mapper properties (i.e. deferred," + "column_property(), relationship(), etc.) must " + "be declared as @declared_attr callables " + "on declarative mixin classes.") + elif isinstance(obj, declarative_props): + oldclassprop = isinstance(obj, util.classproperty) + if not oldclassprop and obj._cascading: + dict_[name] = column_copies[obj] = \ + ret = obj.__get__(obj, cls) + else: + if oldclassprop: + util.warn_deprecated( + "Use of sqlalchemy.util.classproperty on " + "declarative classes is deprecated.") + dict_[name] = column_copies[obj] = \ + ret = getattr(cls, name) + if isinstance(ret, (Column, MapperProperty)) and \ + ret.doc is None: + ret.doc = obj.__doc__ + + if inherited_table_args and not tablename: + table_args = None + + self.table_args = table_args + self.tablename = tablename + self.mapper_args_fn = mapper_args_fn + + def _produce_column_copies(self, base): + cls = self.cls + dict_ = self.dict_ + column_copies = self.column_copies + # copy mixin columns to the mapped class + for name, obj in vars(base).items(): + if isinstance(obj, Column): + if getattr(cls, name) is not obj: + # if column has been overridden + # (like by the InstrumentedAttribute of the + # superclass), skip + continue + elif obj.foreign_keys: raise exc.InvalidRequestError( - "Mapper properties (i.e. deferred," - "column_property(), relationship(), etc.) must " - "be declared as @declared_attr callables " - "on declarative mixin classes.") - elif isinstance(obj, declarative_props): - dict_[name] = ret = \ - column_copies[obj] = getattr(cls, name) - if isinstance(ret, (Column, MapperProperty)) and \ - ret.doc is None: - ret.doc = obj.__doc__ - - # apply inherited columns as we should - for k, v in potential_columns.items(): - dict_[k] = v - - if inherited_table_args and not tablename: - table_args = None - - clsregistry.add_class(classname, cls) - our_stuff = util.OrderedDict() - - for k in list(dict_): - - # TODO: improve this ? all dunders ? - if k in ('__table__', '__tablename__', '__mapper_args__'): - continue - - value = dict_[k] - if isinstance(value, declarative_props): - value = getattr(cls, k) - - elif isinstance(value, QueryableAttribute) and \ - value.class_ is not cls and \ - value.key != k: - # detect a QueryableAttribute that's already mapped being - # assigned elsewhere in userland, turn into a synonym() - value = synonym(value.key) - setattr(cls, k, value) - - if (isinstance(value, tuple) and len(value) == 1 and - isinstance(value[0], (Column, MapperProperty))): - util.warn("Ignoring declarative-like tuple value of attribute " - "%s: possibly a copy-and-paste error with a comma " - "left at the end of the line?" % k) - continue - if not isinstance(value, (Column, MapperProperty)): - if not k.startswith('__'): - dict_.pop(k) - setattr(cls, k, value) - continue - if k == 'metadata': - raise exc.InvalidRequestError( - "Attribute name 'metadata' is reserved " - "for the MetaData instance when using a " - "declarative base class." - ) - prop = clsregistry._deferred_relationship(cls, value) - our_stuff[k] = prop - - # set up attributes in the order they were created - our_stuff.sort(key=lambda key: our_stuff[key]._creation_order) - - # extract columns from the class dict - declared_columns = set() - name_to_prop_key = collections.defaultdict(set) - for key, c in list(our_stuff.items()): - if isinstance(c, (ColumnProperty, CompositeProperty)): - for col in c.columns: - if isinstance(col, Column) and \ - col.table is None: - _undefer_column_name(key, col) - if not isinstance(c, CompositeProperty): - name_to_prop_key[col.name].add(key) - declared_columns.add(col) - elif isinstance(c, Column): - _undefer_column_name(key, c) - name_to_prop_key[c.name].add(key) - declared_columns.add(c) - # if the column is the same name as the key, - # remove it from the explicit properties dict. - # the normal rules for assigning column-based properties - # will take over, including precedence of columns - # in multi-column ColumnProperties. - if key == c.key: - del our_stuff[key] - - for name, keys in name_to_prop_key.items(): - if len(keys) > 1: - util.warn( - "On class %r, Column object %r named directly multiple times, " - "only one will be used: %s" % - (classname, name, (", ".join(sorted(keys)))) - ) + "Columns with foreign keys to other columns " + "must be declared as @declared_attr callables " + "on declarative mixin classes. ") + elif name not in dict_ and not ( + '__table__' in dict_ and + (obj.name or name) in dict_['__table__'].c + ): + column_copies[obj] = copy_ = obj.copy() + copy_._creation_order = obj._creation_order + setattr(cls, name, copy_) + dict_[name] = copy_ - declared_columns = sorted( - declared_columns, key=lambda c: c._creation_order) - table = None + def _extract_mappable_attributes(self): + cls = self.cls + dict_ = self.dict_ - if hasattr(cls, '__table_cls__'): - table_cls = util.unbound_method_to_callable(cls.__table_cls__) - else: - table_cls = Table - - if '__table__' not in dict_: - if tablename is not None: - - args, table_kw = (), {} - if table_args: - if isinstance(table_args, dict): - table_kw = table_args - elif isinstance(table_args, tuple): - if isinstance(table_args[-1], dict): - args, table_kw = table_args[0:-1], table_args[-1] - else: - args = table_args - - autoload = dict_.get('__autoload__') - if autoload: - table_kw['autoload'] = True - - cls.__table__ = table = table_cls( - tablename, cls.metadata, - *(tuple(declared_columns) + tuple(args)), - **table_kw) - else: - table = cls.__table__ - if declared_columns: - for c in declared_columns: - if not table.c.contains_column(c): - raise exc.ArgumentError( - "Can't add additional column %r when " - "specifying __table__" % c.key - ) + our_stuff = self.properties - if hasattr(cls, '__mapper_cls__'): - mapper_cls = util.unbound_method_to_callable(cls.__mapper_cls__) - else: - mapper_cls = mapper + for k in list(dict_): - for c in cls.__bases__: - if _declared_mapping_info(c) is not None: - inherits = c - break - else: - inherits = None + if k in ('__table__', '__tablename__', '__mapper_args__'): + continue - if table is None and inherits is None: - raise exc.InvalidRequestError( - "Class %r does not have a __table__ or __tablename__ " - "specified and does not inherit from an existing " - "table-mapped class." % cls - ) - elif inherits: - inherited_mapper = _declared_mapping_info(inherits) - inherited_table = inherited_mapper.local_table - inherited_mapped_table = inherited_mapper.mapped_table - - if table is None: - # single table inheritance. - # ensure no table args - if table_args: - raise exc.ArgumentError( - "Can't place __table_args__ on an inherited class " - "with no table." + value = dict_[k] + if isinstance(value, declarative_props): + value = getattr(cls, k) + + elif isinstance(value, QueryableAttribute) and \ + value.class_ is not cls and \ + value.key != k: + # detect a QueryableAttribute that's already mapped being + # assigned elsewhere in userland, turn into a synonym() + value = synonym(value.key) + setattr(cls, k, value) + + if (isinstance(value, tuple) and len(value) == 1 and + isinstance(value[0], (Column, MapperProperty))): + util.warn("Ignoring declarative-like tuple value of attribute " + "%s: possibly a copy-and-paste error with a comma " + "left at the end of the line?" % k) + continue + elif not isinstance(value, (Column, MapperProperty)): + # using @declared_attr for some object that + # isn't Column/MapperProperty; remove from the dict_ + # and place the evaulated value onto the class. + if not k.startswith('__'): + dict_.pop(k) + setattr(cls, k, value) + continue + # we expect to see the name 'metadata' in some valid cases; + # however at this point we see it's assigned to something trying + # to be mapped, so raise for that. + elif k == 'metadata': + raise exc.InvalidRequestError( + "Attribute name 'metadata' is reserved " + "for the MetaData instance when using a " + "declarative base class." + ) + prop = clsregistry._deferred_relationship(cls, value) + our_stuff[k] = prop + + def _extract_declared_columns(self): + our_stuff = self.properties + + # set up attributes in the order they were created + our_stuff.sort(key=lambda key: our_stuff[key]._creation_order) + + # extract columns from the class dict + declared_columns = self.declared_columns + name_to_prop_key = collections.defaultdict(set) + for key, c in list(our_stuff.items()): + if isinstance(c, (ColumnProperty, CompositeProperty)): + for col in c.columns: + if isinstance(col, Column) and \ + col.table is None: + _undefer_column_name(key, col) + if not isinstance(c, CompositeProperty): + name_to_prop_key[col.name].add(key) + declared_columns.add(col) + elif isinstance(c, Column): + _undefer_column_name(key, c) + name_to_prop_key[c.name].add(key) + declared_columns.add(c) + # if the column is the same name as the key, + # remove it from the explicit properties dict. + # the normal rules for assigning column-based properties + # will take over, including precedence of columns + # in multi-column ColumnProperties. + if key == c.key: + del our_stuff[key] + + for name, keys in name_to_prop_key.items(): + if len(keys) > 1: + util.warn( + "On class %r, Column object %r named " + "directly multiple times, " + "only one will be used: %s" % + (self.classname, name, (", ".join(sorted(keys)))) ) - # add any columns declared here to the inherited table. - for c in declared_columns: - if c.primary_key: - raise exc.ArgumentError( - "Can't place primary key columns on an inherited " - "class with no table." - ) - if c.name in inherited_table.c: - if inherited_table.c[c.name] is c: - continue - raise exc.ArgumentError( - "Column '%s' on class %s conflicts with " - "existing column '%s'" % - (c, cls, inherited_table.c[c.name]) - ) - inherited_table.append_column(c) - if inherited_mapped_table is not None and \ - inherited_mapped_table is not inherited_table: - inherited_mapped_table._refresh_for_new_column(c) - - defer_map = hasattr(cls, '_sa_decl_prepare') - if defer_map: - cfg_cls = _DeferredMapperConfig - else: - cfg_cls = _MapperConfig - mt = cfg_cls(mapper_cls, - cls, table, - inherits, - declared_columns, - column_copies, - our_stuff, - mapper_args_fn) - if not defer_map: - mt.map() + def _setup_table(self): + cls = self.cls + tablename = self.tablename + table_args = self.table_args + dict_ = self.dict_ + declared_columns = self.declared_columns -class _MapperConfig(object): + declared_columns = self.declared_columns = sorted( + declared_columns, key=lambda c: c._creation_order) + table = None - mapped_table = None - - def __init__(self, mapper_cls, - cls, - table, - inherits, - declared_columns, - column_copies, - properties, mapper_args_fn): - self.mapper_cls = mapper_cls - self.cls = cls + if hasattr(cls, '__table_cls__'): + table_cls = util.unbound_method_to_callable(cls.__table_cls__) + else: + table_cls = Table + + if '__table__' not in dict_: + if tablename is not None: + + args, table_kw = (), {} + if table_args: + if isinstance(table_args, dict): + table_kw = table_args + elif isinstance(table_args, tuple): + if isinstance(table_args[-1], dict): + args, table_kw = table_args[0:-1], table_args[-1] + else: + args = table_args + + autoload = dict_.get('__autoload__') + if autoload: + table_kw['autoload'] = True + + cls.__table__ = table = table_cls( + tablename, cls.metadata, + *(tuple(declared_columns) + tuple(args)), + **table_kw) + else: + table = cls.__table__ + if declared_columns: + for c in declared_columns: + if not table.c.contains_column(c): + raise exc.ArgumentError( + "Can't add additional column %r when " + "specifying __table__" % c.key + ) self.local_table = table - self.inherits = inherits - self.properties = properties - self.mapper_args_fn = mapper_args_fn - self.declared_columns = declared_columns - self.column_copies = column_copies + + def _setup_inheritance(self): + table = self.local_table + cls = self.cls + table_args = self.table_args + declared_columns = self.declared_columns + for c in cls.__bases__: + if _declared_mapping_info(c) is not None and \ + not _get_immediate_cls_attr( + c, '_sa_decl_prepare_nocascade'): + self.inherits = c + break + else: + self.inherits = None + + if table is None and self.inherits is None and \ + not _get_immediate_cls_attr(cls, '__no_table__'): + + raise exc.InvalidRequestError( + "Class %r does not have a __table__ or __tablename__ " + "specified and does not inherit from an existing " + "table-mapped class." % cls + ) + elif self.inherits: + inherited_mapper = _declared_mapping_info(self.inherits) + inherited_table = inherited_mapper.local_table + inherited_mapped_table = inherited_mapper.mapped_table + + if table is None: + # single table inheritance. + # ensure no table args + if table_args: + raise exc.ArgumentError( + "Can't place __table_args__ on an inherited class " + "with no table." + ) + # add any columns declared here to the inherited table. + for c in declared_columns: + if c.primary_key: + raise exc.ArgumentError( + "Can't place primary key columns on an inherited " + "class with no table." + ) + if c.name in inherited_table.c: + if inherited_table.c[c.name] is c: + continue + raise exc.ArgumentError( + "Column '%s' on class %s conflicts with " + "existing column '%s'" % + (c, cls, inherited_table.c[c.name]) + ) + inherited_table.append_column(c) + if inherited_mapped_table is not None and \ + inherited_mapped_table is not inherited_table: + inherited_mapped_table._refresh_for_new_column(c) def _prepare_mapper_arguments(self): properties = self.properties @@ -401,20 +489,31 @@ class _MapperConfig(object): properties[k] = [col] + p.columns result_mapper_args = mapper_args.copy() result_mapper_args['properties'] = properties - return result_mapper_args + self.mapper_args = result_mapper_args def map(self): - mapper_args = self._prepare_mapper_arguments() - self.cls.__mapper__ = self.mapper_cls( + self._prepare_mapper_arguments() + if hasattr(self.cls, '__mapper_cls__'): + mapper_cls = util.unbound_method_to_callable( + self.cls.__mapper_cls__) + else: + mapper_cls = mapper + + self.cls.__mapper__ = mp_ = mapper_cls( self.cls, self.local_table, - **mapper_args + **self.mapper_args ) + del mp_.class_manager.info['declared_attr_reg'] + return mp_ class _DeferredMapperConfig(_MapperConfig): _configs = util.OrderedDict() + def _early_mapping(self): + pass + @property def cls(self): return self._cls() @@ -466,7 +565,7 @@ class _DeferredMapperConfig(_MapperConfig): def map(self): self._configs.pop(self._cls, None) - super(_DeferredMapperConfig, self).map() + return super(_DeferredMapperConfig, self).map() def _add_attribute(cls, key, value): diff --git a/lib/sqlalchemy/ext/declarative/clsregistry.py b/lib/sqlalchemy/ext/declarative/clsregistry.py index 4595b857a..3ef63a5ae 100644 --- a/lib/sqlalchemy/ext/declarative/clsregistry.py +++ b/lib/sqlalchemy/ext/declarative/clsregistry.py @@ -103,7 +103,12 @@ class _MultipleClassMarker(object): self.on_remove() def add_item(self, item): - modules = set([cls().__module__ for cls in self.contents]) + # protect against class registration race condition against + # asynchronous garbage collection calling _remove_item, + # [ticket:3208] + modules = set([ + cls.__module__ for cls in + [ref() for ref in self.contents] if cls is not None]) if item.__module__ in modules: util.warn( "This declarative base already contains a class with the " diff --git a/lib/sqlalchemy/orm/collections.py b/lib/sqlalchemy/orm/collections.py index c2754d58f..356a8a3b9 100644 --- a/lib/sqlalchemy/orm/collections.py +++ b/lib/sqlalchemy/orm/collections.py @@ -861,11 +861,24 @@ def _instrument_class(cls): "Can not instrument a built-in type. Use a " "subclass, even a trivial one.") + roles, methods = _locate_roles_and_methods(cls) + + _setup_canned_roles(cls, roles, methods) + + _assert_required_roles(cls, roles, methods) + + _set_collection_attributes(cls, roles, methods) + + +def _locate_roles_and_methods(cls): + """search for _sa_instrument_role-decorated methods in + method resolution order, assign to roles. + + """ + roles = {} methods = {} - # search for _sa_instrument_role-decorated methods in - # method resolution order, assign to roles for supercls in cls.__mro__: for name, method in vars(supercls).items(): if not util.callable(method): @@ -890,14 +903,19 @@ def _instrument_class(cls): assert op in ('fire_append_event', 'fire_remove_event') after = op if before: - methods[name] = before[0], before[1], after + methods[name] = before + (after, ) elif after: methods[name] = None, None, after + return roles, methods + - # see if this class has "canned" roles based on a known - # collection type (dict, set, list). Apply those roles - # as needed to the "roles" dictionary, and also - # prepare "decorator" methods +def _setup_canned_roles(cls, roles, methods): + """see if this class has "canned" roles based on a known + collection type (dict, set, list). Apply those roles + as needed to the "roles" dictionary, and also + prepare "decorator" methods + + """ collection_type = util.duck_type_collection(cls) if collection_type in __interfaces: canned_roles, decorators = __interfaces[collection_type] @@ -911,8 +929,12 @@ def _instrument_class(cls): not hasattr(fn, '_sa_instrumented')): setattr(cls, method, decorator(fn)) - # ensure all roles are present, and apply implicit instrumentation if - # needed + +def _assert_required_roles(cls, roles, methods): + """ensure all roles are present, and apply implicit instrumentation if + needed + + """ if 'appender' not in roles or not hasattr(cls, roles['appender']): raise sa_exc.ArgumentError( "Type %s must elect an appender method to be " @@ -934,8 +956,12 @@ def _instrument_class(cls): "Type %s must elect an iterator method to be " "a collection class" % cls.__name__) - # apply ad-hoc instrumentation from decorators, class-level defaults - # and implicit role declarations + +def _set_collection_attributes(cls, roles, methods): + """apply ad-hoc instrumentation from decorators, class-level defaults + and implicit role declarations + + """ for method_name, (before, argument, after) in methods.items(): setattr(cls, method_name, _instrument_membership_mutator(getattr(cls, method_name), diff --git a/lib/sqlalchemy/orm/events.py b/lib/sqlalchemy/orm/events.py index c50a7b062..9ea0dd834 100644 --- a/lib/sqlalchemy/orm/events.py +++ b/lib/sqlalchemy/orm/events.py @@ -61,7 +61,8 @@ class InstrumentationEvents(event.Events): @classmethod def _listen(cls, event_key, propagate=True, **kw): target, identifier, fn = \ - event_key.dispatch_target, event_key.identifier, event_key.fn + event_key.dispatch_target, event_key.identifier, \ + event_key._listen_fn def listen(target_cls, *arg): listen_cls = target() @@ -192,7 +193,8 @@ class InstanceEvents(event.Events): @classmethod def _listen(cls, event_key, raw=False, propagate=False, **kw): target, identifier, fn = \ - event_key.dispatch_target, event_key.identifier, event_key.fn + event_key.dispatch_target, event_key.identifier, \ + event_key._listen_fn if not raw: def wrap(state, *arg, **kw): @@ -498,7 +500,8 @@ class MapperEvents(event.Events): def _listen( cls, event_key, raw=False, retval=False, propagate=False, **kw): target, identifier, fn = \ - event_key.dispatch_target, event_key.identifier, event_key.fn + event_key.dispatch_target, event_key.identifier, \ + event_key._listen_fn if identifier in ("before_configured", "after_configured") and \ target is not mapperlib.Mapper: @@ -1493,7 +1496,8 @@ class AttributeEvents(event.Events): propagate=False): target, identifier, fn = \ - event_key.dispatch_target, event_key.identifier, event_key.fn + event_key.dispatch_target, event_key.identifier, \ + event_key._listen_fn if active_history: target.dispatch._active_history = True diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 984f05256..082dae054 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -426,6 +426,12 @@ class Mapper(InspectionAttr): thus persisting the value to the ``discriminator`` column in the database. + .. warning:: + + Currently, **only one discriminator column may be set**, typically + on the base-most class in the hierarchy. "Cascading" polymorphic + columns are not yet supported. + .. seealso:: :ref:`inheritance_toplevel` @@ -1080,6 +1086,9 @@ class Mapper(InspectionAttr): auto-session attachment logic. """ + + # when using declarative as of 1.0, the register_class has + # already happened from within declarative. manager = attributes.manager_of_class(self.class_) if self.non_primary: @@ -1102,18 +1111,14 @@ class Mapper(InspectionAttr): "create a non primary Mapper. clear_mappers() will " "remove *all* current mappers from all classes." % self.class_) - # else: - # a ClassManager may already exist as - # ClassManager.instrument_attribute() creates - # new managers for each subclass if they don't yet exist. + + if manager is None: + manager = instrumentation.register_class(self.class_) _mapper_registry[self] = True self.dispatch.instrument_class(self, self.class_) - if manager is None: - manager = instrumentation.register_class(self.class_) - self.class_manager = manager manager.mapper = self @@ -2657,7 +2662,7 @@ def configure_mappers(): mapper._expire_memoizations() mapper.dispatch.mapper_configured( mapper, mapper.class_) - except: + except Exception: exc = sys.exc_info()[1] if not hasattr(exc, '_configure_failed'): mapper._configure_failed = exc diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index 1288c910f..c4a9402fb 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -18,7 +18,7 @@ import operator from itertools import groupby, chain from .. import sql, util, exc as sa_exc, schema from . import attributes, sync, exc as orm_exc, evaluator -from .base import state_str, _attr_as_key +from .base import state_str, _attr_as_key, _entity_descriptor from ..sql import expression from . import loading @@ -560,9 +560,9 @@ def _collect_delete_commands(base_mapper, uowtransaction, table, state, state_dict, col) if value is None: raise orm_exc.FlushError( - "Can't delete from table " + "Can't delete from table %s " "using NULL for primary " - "key value") + "key value on column %s" % (table, col)) if update_version_id is not None and \ mapper.version_id_col in mapper._cols_by_table[table]: @@ -1112,6 +1112,7 @@ class BulkUpdate(BulkUD): super(BulkUpdate, self).__init__(query) self.query._no_select_modifiers("update") self.values = values + self.mapper = self.query._mapper_zero_or_none() @classmethod def factory(cls, query, synchronize_session, values): @@ -1121,9 +1122,40 @@ class BulkUpdate(BulkUD): False: BulkUpdate }, synchronize_session, query, values) + def _resolve_string_to_expr(self, key): + if self.mapper and isinstance(key, util.string_types): + attr = _entity_descriptor(self.mapper, key) + return attr.__clause_element__() + else: + return key + + def _resolve_key_to_attrname(self, key): + if self.mapper and isinstance(key, util.string_types): + attr = _entity_descriptor(self.mapper, key) + return attr.property.key + elif isinstance(key, attributes.InstrumentedAttribute): + return key.key + elif hasattr(key, '__clause_element__'): + key = key.__clause_element__() + + if self.mapper and isinstance(key, expression.ColumnElement): + try: + attr = self.mapper._columntoproperty[key] + except orm_exc.UnmappedColumnError: + return None + else: + return attr.key + else: + raise sa_exc.InvalidRequestError( + "Invalid expression type: %r" % key) + def _do_exec(self): + values = dict( + (self._resolve_string_to_expr(k), v) + for k, v in self.values.items() + ) update_stmt = sql.update(self.primary_table, - self.context.whereclause, self.values) + self.context.whereclause, values) self.result = self.query.session.execute( update_stmt, params=self.query._params) @@ -1169,9 +1201,10 @@ class BulkUpdateEvaluate(BulkEvaluate, BulkUpdate): def _additional_evaluators(self, evaluator_compiler): self.value_evaluators = {} for key, value in self.values.items(): - key = _attr_as_key(key) - self.value_evaluators[key] = evaluator_compiler.process( - expression._literal_as_binds(value)) + key = self._resolve_key_to_attrname(key) + if key is not None: + self.value_evaluators[key] = evaluator_compiler.process( + expression._literal_as_binds(value)) def _do_post_synchronize(self): session = self.query.session diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 60948293b..f07060825 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -1145,7 +1145,8 @@ class Query(object): @_generative() def with_hint(self, selectable, text, dialect_name='*'): - """Add an indexing hint for the given entity or selectable to + """Add an indexing or other executional context + hint for the given entity or selectable to this :class:`.Query`. Functionality is passed straight through to @@ -1153,11 +1154,35 @@ class Query(object): with the addition that ``selectable`` can be a :class:`.Table`, :class:`.Alias`, or ORM entity / mapped class /etc. + + .. seealso:: + + :meth:`.Query.with_statement_hint` + """ - selectable = inspect(selectable).selectable + if selectable is not None: + selectable = inspect(selectable).selectable self._with_hints += ((selectable, text, dialect_name),) + def with_statement_hint(self, text, dialect_name='*'): + """add a statement hint to this :class:`.Select`. + + This method is similar to :meth:`.Select.with_hint` except that + it does not require an individual table, and instead applies to the + statement as a whole. + + This feature calls down into :meth:`.Select.with_statement_hint`. + + .. versionadded:: 1.0.0 + + .. seealso:: + + :meth:`.Query.with_hint` + + """ + return self.with_hint(None, text, dialect_name) + @_generative() def execution_options(self, **kwargs): """ Set non-SQL options which take effect during execution. @@ -1810,6 +1835,11 @@ class Query(object): left_entity = prop = None + if isinstance(onclause, interfaces.PropComparator): + of_type = getattr(onclause, '_of_type', None) + else: + of_type = None + if isinstance(onclause, util.string_types): left_entity = self._joinpoint_zero() @@ -1836,8 +1866,6 @@ class Query(object): if isinstance(onclause, interfaces.PropComparator): if right_entity is None: - right_entity = onclause.property.mapper - of_type = getattr(onclause, '_of_type', None) if of_type: right_entity = of_type else: @@ -1919,11 +1947,9 @@ class Query(object): from_obj, r_info.selectable): overlap = True break - elif sql_util.selectables_overlap(l_info.selectable, - r_info.selectable): - overlap = True - if overlap and l_info.selectable is r_info.selectable: + if (overlap or not create_aliases) and \ + l_info.selectable is r_info.selectable: raise sa_exc.InvalidRequestError( "Can't join table/selectable '%s' to itself" % l_info.selectable) @@ -2591,6 +2617,19 @@ class Query(object): SELECT 1 FROM users WHERE users.name = :name_1 ) AS anon_1 + The EXISTS construct is usually used in the WHERE clause:: + + session.query(User.id).filter(q.exists()).scalar() + + Note that some databases such as SQL Server don't allow an + EXISTS expression to be present in the columns clause of a + SELECT. To select a simple boolean value based on the exists + as a WHERE, use :func:`.literal`:: + + from sqlalchemy import literal + + session.query(literal(True)).filter(q.exists()).scalar() + .. versionadded:: 0.8.1 """ @@ -2718,9 +2757,25 @@ class Query(object): Updates rows matched by this query in the database. - :param values: a dictionary with attributes names as keys and literal + E.g.:: + + sess.query(User).filter(User.age == 25).\ + update({User.age: User.age - 10}, synchronize_session='fetch') + + + sess.query(User).filter(User.age == 25).\ + update({"age": User.age - 10}, synchronize_session='evaluate') + + + :param values: a dictionary with attributes names, or alternatively + mapped attributes or SQL expressions, as keys, and literal values or sql expressions as values. + .. versionchanged:: 1.0.0 - string names in the values dictionary + are now resolved against the mapped entity; previously, these + strings were passed as literal column names with no mapper-level + translation. + :param synchronize_session: chooses the strategy to update the attributes on objects in the session. Valid values are: @@ -2758,7 +2813,7 @@ class Query(object): which normally occurs upon :meth:`.Session.commit` or can be forced by using :meth:`.Session.expire_all`. - * As of 0.8, this method will support multiple table updates, as + * The method supports multiple table updates, as detailed in :ref:`multi_table_updates`, and this behavior does extend to support updates of joined-inheritance and other multiple table mappings. However, the **join condition of an inheritance @@ -2789,12 +2844,6 @@ class Query(object): """ - # TODO: value keys need to be mapped to corresponding sql cols and - # instr.attr.s to string keys - # TODO: updates of manytoone relationships need to be converted to - # fk assignments - # TODO: cascades need handling. - update_op = persistence.BulkUpdate.factory( self, synchronize_session, values) update_op.exec_() diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py index 56a33742d..86f1b3f82 100644 --- a/lib/sqlalchemy/orm/relationships.py +++ b/lib/sqlalchemy/orm/relationships.py @@ -16,6 +16,7 @@ and `secondaryjoin` aspects of :func:`.relationship`. from __future__ import absolute_import from .. import sql, util, exc as sa_exc, schema, log +import weakref from .util import CascadeOptions, _orm_annotate, _orm_deannotate from . import dependency from . import attributes @@ -1532,6 +1533,7 @@ class RelationshipProperty(StrategizedProperty): self._check_cascade_settings(self._cascade) self._post_init() self._generate_backref() + self._join_condition._warn_for_conflicting_sync_targets() super(RelationshipProperty, self).do_init() self._lazy_strategy = self._get_strategy((("lazy", "select"),)) @@ -2519,6 +2521,60 @@ class JoinCondition(object): self.secondary_synchronize_pairs = \ self._deannotate_pairs(secondary_sync_pairs) + _track_overlapping_sync_targets = weakref.WeakKeyDictionary() + + def _warn_for_conflicting_sync_targets(self): + if not self.support_sync: + return + + # we would like to detect if we are synchronizing any column + # pairs in conflict with another relationship that wishes to sync + # an entirely different column to the same target. This is a + # very rare edge case so we will try to minimize the memory/overhead + # impact of this check + for from_, to_ in [ + (from_, to_) for (from_, to_) in self.synchronize_pairs + ] + [ + (from_, to_) for (from_, to_) in self.secondary_synchronize_pairs + ]: + # save ourselves a ton of memory and overhead by only + # considering columns that are subject to a overlapping + # FK constraints at the core level. This condition can arise + # if multiple relationships overlap foreign() directly, but + # we're going to assume it's typically a ForeignKeyConstraint- + # level configuration that benefits from this warning. + if len(to_.foreign_keys) < 2: + continue + + if to_ not in self._track_overlapping_sync_targets: + self._track_overlapping_sync_targets[to_] = \ + weakref.WeakKeyDictionary({self.prop: from_}) + else: + other_props = [] + prop_to_from = self._track_overlapping_sync_targets[to_] + for pr, fr_ in prop_to_from.items(): + if pr.mapper in mapperlib._mapper_registry and \ + fr_ is not from_ and \ + pr not in self.prop._reverse_property: + other_props.append((pr, fr_)) + + if other_props: + util.warn( + "relationship '%s' will copy column %s to column %s, " + "which conflicts with relationship(s): %s. " + "Consider applying " + "viewonly=True to read-only relationships, or provide " + "a primaryjoin condition marking writable columns " + "with the foreign() annotation." % ( + self.prop, + from_, to_, + ", ".join( + "'%s' (copies %s to %s)" % (pr, fr_, to_) + for (pr, fr_) in other_props) + ) + ) + self._track_overlapping_sync_targets[to_][self.prop] = from_ + @util.memoized_property def remote_columns(self): return self._gather_join_annotations("remote") diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 1611688b0..ef911824c 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -294,7 +294,7 @@ class SessionTransaction(object): for s in self.session.identity_map.all_states(): s._expire(s.dict, self.session.identity_map._modified) for s in self._deleted: - s.session_id = None + s._detach() self._deleted.clear() elif self.nested: self._parent._new.update(self._new) @@ -644,14 +644,8 @@ class Session(_SessionClassMethods): SessionExtension._adapt_listener(self, ext) if binds is not None: - for mapperortable, bind in binds.items(): - insp = inspect(mapperortable) - if insp.is_selectable: - self.bind_table(mapperortable, bind) - elif insp.is_mapper: - self.bind_mapper(mapperortable, bind) - else: - assert False + for key, bind in binds.items(): + self._add_bind(key, bind) if not self.autocommit: self.begin() @@ -1029,40 +1023,47 @@ class Session(_SessionClassMethods): # TODO: + crystallize + document resolution order # vis. bind_mapper/bind_table - def bind_mapper(self, mapper, bind): - """Bind operations for a mapper to a Connectable. - - mapper - A mapper instance or mapped class + def _add_bind(self, key, bind): + try: + insp = inspect(key) + except sa_exc.NoInspectionAvailable: + if not isinstance(key, type): + raise exc.ArgumentError( + "Not acceptable bind target: %s" % + key) + else: + self.__binds[key] = bind + else: + if insp.is_selectable: + self.__binds[insp] = bind + elif insp.is_mapper: + self.__binds[insp.class_] = bind + for selectable in insp._all_tables: + self.__binds[selectable] = bind + else: + raise exc.ArgumentError( + "Not acceptable bind target: %s" % + key) - bind - Any Connectable: a :class:`.Engine` or :class:`.Connection`. + def bind_mapper(self, mapper, bind): + """Associate a :class:`.Mapper` with a "bind", e.g. a :class:`.Engine` + or :class:`.Connection`. - All subsequent operations involving this mapper will use the given - `bind`. + The given mapper is added to a lookup used by the + :meth:`.Session.get_bind` method. """ - if isinstance(mapper, type): - mapper = class_mapper(mapper) - - self.__binds[mapper.base_mapper] = bind - for t in mapper._all_tables: - self.__binds[t] = bind + self._add_bind(mapper, bind) def bind_table(self, table, bind): - """Bind operations on a Table to a Connectable. + """Associate a :class:`.Table` with a "bind", e.g. a :class:`.Engine` + or :class:`.Connection`. - table - A :class:`.Table` instance - - bind - Any Connectable: a :class:`.Engine` or :class:`.Connection`. - - All subsequent operations involving this :class:`.Table` will use the - given `bind`. + The given mapper is added to a lookup used by the + :meth:`.Session.get_bind` method. """ - self.__binds[table] = bind + self._add_bind(table, bind) def get_bind(self, mapper=None, clause=None): """Return a "bind" to which this :class:`.Session` is bound. @@ -1116,6 +1117,7 @@ class Session(_SessionClassMethods): bound :class:`.MetaData`. """ + if mapper is clause is None: if self.bind: return self.bind @@ -1125,15 +1127,23 @@ class Session(_SessionClassMethods): "Connection, and no context was provided to locate " "a binding.") - c_mapper = mapper is not None and _class_to_mapper(mapper) or None + if mapper is not None: + try: + mapper = inspect(mapper) + except sa_exc.NoInspectionAvailable: + if isinstance(mapper, type): + raise exc.UnmappedClassError(mapper) + else: + raise - # manually bound? if self.__binds: - if c_mapper: - if c_mapper.base_mapper in self.__binds: - return self.__binds[c_mapper.base_mapper] - elif c_mapper.mapped_table in self.__binds: - return self.__binds[c_mapper.mapped_table] + if mapper: + for cls in mapper.class_.__mro__: + if cls in self.__binds: + return self.__binds[cls] + if clause is None: + clause = mapper.mapped_table + if clause is not None: for t in sql_util.find_tables(clause, include_crud=True): if t in self.__binds: @@ -1145,12 +1155,12 @@ class Session(_SessionClassMethods): if isinstance(clause, sql.expression.ClauseElement) and clause.bind: return clause.bind - if c_mapper and c_mapper.mapped_table.bind: - return c_mapper.mapped_table.bind + if mapper and mapper.mapped_table.bind: + return mapper.mapped_table.bind context = [] if mapper is not None: - context.append('mapper %s' % c_mapper) + context.append('mapper %s' % mapper) if clause is not None: context.append('SQL expression') @@ -1402,6 +1412,7 @@ class Session(_SessionClassMethods): state._detach() elif self.transaction: self.transaction._deleted.pop(state, None) + state._detach() def _register_newly_persistent(self, states): for state in states: @@ -2478,16 +2489,19 @@ def make_transient_to_detached(instance): def object_session(instance): - """Return the ``Session`` to which instance belongs. + """Return the :class:`.Session` to which the given instance belongs. - If the instance is not a mapped instance, an error is raised. + This is essentially the same as the :attr:`.InstanceState.session` + accessor. See that attribute for details. """ try: - return _state_session(attributes.instance_state(instance)) + state = attributes.instance_state(instance) except exc.NO_STATE: raise exc.UnmappedInstanceError(instance) + else: + return _state_session(state) _new_sessionid = util.counter() diff --git a/lib/sqlalchemy/orm/state.py b/lib/sqlalchemy/orm/state.py index 3c12fda1a..560149de5 100644 --- a/lib/sqlalchemy/orm/state.py +++ b/lib/sqlalchemy/orm/state.py @@ -145,7 +145,16 @@ class InstanceState(interfaces.InspectionAttr): @util.dependencies("sqlalchemy.orm.session") def session(self, sessionlib): """Return the owning :class:`.Session` for this instance, - or ``None`` if none available.""" + or ``None`` if none available. + + Note that the result here can in some cases be *different* + from that of ``obj in session``; an object that's been deleted + will report as not ``in session``, however if the transaction is + still in progress, this attribute will still refer to that session. + Only when the transaction is completed does the object become + fully detached under normal circumstances. + + """ return sessionlib._state_session(self) @property @@ -258,8 +267,8 @@ class InstanceState(interfaces.InspectionAttr): try: return manager.original_init(*mixed[1:], **kwargs) except: - manager.dispatch.init_failure(self, args, kwargs) - raise + with util.safe_reraise(): + manager.dispatch.init_failure(self, args, kwargs) def get_history(self, key, passive): return self.manager[key].impl.get_history(self, self.dict, passive) diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 734f9d5e6..ad610a4ac 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -30,13 +30,10 @@ class CascadeOptions(frozenset): 'all', 'none', 'delete-orphan']) _allowed_cascades = all_cascades - def __new__(cls, arg): - values = set([ - c for c - in re.split('\s*,\s*', arg or "") - if c - ]) - + def __new__(cls, value_list): + if isinstance(value_list, str) or value_list is None: + return cls.from_string(value_list) + values = set(value_list) if values.difference(cls._allowed_cascades): raise sa_exc.ArgumentError( "Invalid cascade option(s): %s" % @@ -70,6 +67,14 @@ class CascadeOptions(frozenset): ",".join([x for x in sorted(self)]) ) + @classmethod + def from_string(cls, arg): + values = [ + c for c + in re.split('\s*,\s*', arg or "") + if c + ] + return cls(values) def _validator_events( desc, key, validator, include_removes, include_backrefs): @@ -804,6 +809,16 @@ class _ORMJoin(expression.Join): expression.Join.__init__(self, left, right, onclause, isouter) + if not prop and getattr(right_info, 'mapper', None) \ + and right_info.mapper.single: + # if single inheritance target and we are using a manual + # or implicit ON clause, augment it the same way we'd augment the + # WHERE. + single_crit = right_info.mapper._single_table_criterion + if right_info.is_aliased_class: + single_crit = right_info._adapter.traverse(single_crit) + self.onclause = self.onclause & single_crit + def join(self, right, onclause=None, isouter=False, join_to_left=None): return _ORMJoin(self, right, onclause, isouter) diff --git a/lib/sqlalchemy/pool.py b/lib/sqlalchemy/pool.py index bc9affe4a..a174df784 100644 --- a/lib/sqlalchemy/pool.py +++ b/lib/sqlalchemy/pool.py @@ -248,9 +248,7 @@ class Pool(log.Identified): self.logger.debug("Closing connection %r", connection) try: self._dialect.do_close(connection) - except (SystemExit, KeyboardInterrupt): - raise - except: + except Exception: self.logger.error("Exception closing connection %r", connection, exc_info=True) @@ -441,8 +439,8 @@ class _ConnectionRecord(object): try: dbapi_connection = rec.get_connection() except: - rec.checkin() - raise + with util.safe_reraise(): + rec.checkin() echo = pool._should_log_debug() fairy = _ConnectionFairy(dbapi_connection, rec, echo) rec.fairy_ref = weakref.ref( @@ -569,12 +567,12 @@ def _finalize_fairy(connection, connection_record, # Immediately close detached instances if not connection_record: pool._close_connection(connection) - except Exception as e: + except BaseException as e: pool.logger.error( "Exception during reset or similar", exc_info=True) if connection_record: connection_record.invalidate(e=e) - if isinstance(e, (SystemExit, KeyboardInterrupt)): + if not isinstance(e, Exception): raise if connection_record: @@ -842,9 +840,7 @@ class SingletonThreadPool(Pool): for conn in self._all_conns: try: conn.close() - except (SystemExit, KeyboardInterrupt): - raise - except: + except Exception: # pysqlite won't even let you close a conn from a thread # that didn't create it pass @@ -962,8 +958,8 @@ class QueuePool(Pool): try: return self._create_connection() except: - self._dec_overflow() - raise + with util.safe_reraise(): + self._dec_overflow() else: return self._do_get() diff --git a/lib/sqlalchemy/sql/__init__.py b/lib/sqlalchemy/sql/__init__.py index 4d013859c..351e08d0b 100644 --- a/lib/sqlalchemy/sql/__init__.py +++ b/lib/sqlalchemy/sql/__init__.py @@ -38,6 +38,7 @@ from .expression import ( false, False_, func, + funcfilter, insert, intersect, intersect_all, diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 5149fa4fe..5fa78ad0f 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -24,12 +24,10 @@ To generate user-defined SQL strings, see """ import re -from . import schema, sqltypes, operators, functions, \ - util as sql_util, visitors, elements, selectable, base +from . import schema, sqltypes, operators, functions, visitors, \ + elements, selectable, crud from .. import util, exc -import decimal import itertools -import operator RESERVED_WORDS = set([ 'all', 'analyse', 'analyze', 'and', 'any', 'array', @@ -64,17 +62,6 @@ BIND_TEMPLATES = { 'named': ":%(name)s" } -REQUIRED = util.symbol('REQUIRED', """ -Placeholder for the value within a :class:`.BindParameter` -which is required to be present when the statement is passed -to :meth:`.Connection.execute`. - -This symbol is typically used when a :func:`.expression.insert` -or :func:`.expression.update` statement is compiled without parameter -values present. - -""") - OPERATORS = { # binary @@ -725,7 +712,6 @@ class SQLCompiler(Compiled): for c in clauselist.clauses) if s) - def visit_case(self, clause, **kwargs): x = "CASE " if clause.value is not None: @@ -760,6 +746,12 @@ class SQLCompiler(Compiled): ) ) + def visit_funcfilter(self, funcfilter, **kwargs): + return "%s FILTER (WHERE %s)" % ( + funcfilter.func._compiler_dispatch(self, **kwargs), + funcfilter.criterion._compiler_dispatch(self, **kwargs) + ) + def visit_extract(self, extract, **kwargs): field = self.extract_map.get(extract.field, extract.field) return "EXTRACT(%s FROM %s)" % ( @@ -819,8 +811,9 @@ class SQLCompiler(Compiled): text += " GROUP BY " + group_by text += self.order_by_clause(cs, **kwargs) - text += (cs._limit_clause is not None or cs._offset_clause is not None) and \ - self.limit_clause(cs) or "" + text += (cs._limit_clause is not None + or cs._offset_clause is not None) and \ + self.limit_clause(cs, **kwargs) or "" if self.ctes and \ compound_index == 0 and toplevel: @@ -876,15 +869,15 @@ class SQLCompiler(Compiled): isinstance(binary.right, elements.BindParameter): kw['literal_binds'] = True - operator = binary.operator - disp = getattr(self, "visit_%s_binary" % operator.__name__, None) + operator_ = binary.operator + disp = getattr(self, "visit_%s_binary" % operator_.__name__, None) if disp: - return disp(binary, operator, **kw) + return disp(binary, operator_, **kw) else: try: - opstring = OPERATORS[operator] + opstring = OPERATORS[operator_] except KeyError: - raise exc.UnsupportedCompilationError(self, operator) + raise exc.UnsupportedCompilationError(self, operator_) else: return self._generate_generic_binary(binary, opstring, **kw) @@ -966,7 +959,7 @@ class SQLCompiler(Compiled): ' ESCAPE ' + self.render_literal_value(escape, sqltypes.STRINGTYPE) if escape else '' - ) + ) def visit_notlike_op_binary(self, binary, operator, **kw): escape = binary.modifiers.get("escape", None) @@ -977,7 +970,7 @@ class SQLCompiler(Compiled): ' ESCAPE ' + self.render_literal_value(escape, sqltypes.STRINGTYPE) if escape else '' - ) + ) def visit_ilike_op_binary(self, binary, operator, **kw): escape = binary.modifiers.get("escape", None) @@ -988,7 +981,7 @@ class SQLCompiler(Compiled): ' ESCAPE ' + self.render_literal_value(escape, sqltypes.STRINGTYPE) if escape else '' - ) + ) def visit_notilike_op_binary(self, binary, operator, **kw): escape = binary.modifiers.get("escape", None) @@ -999,7 +992,7 @@ class SQLCompiler(Compiled): ' ESCAPE ' + self.render_literal_value(escape, sqltypes.STRINGTYPE) if escape else '' - ) + ) def visit_between_op_binary(self, binary, operator, **kw): symmetric = binary.modifiers.get("symmetric", False) @@ -1331,6 +1324,9 @@ class SQLCompiler(Compiled): def get_crud_hint_text(self, table, text): return None + def get_statement_hint_text(self, hint_texts): + return " ".join(hint_texts) + def _transform_select_for_nested_joins(self, select): """Rewrite any "a JOIN (b JOIN c)" expression as "a JOIN (select * from b JOIN c) AS anon", to support @@ -1501,29 +1497,7 @@ class SQLCompiler(Compiled): select, transformed_select) return text - correlate_froms = entry['correlate_froms'] - asfrom_froms = entry['asfrom_froms'] - - if asfrom: - froms = select._get_display_froms( - explicit_correlate_froms=correlate_froms.difference( - asfrom_froms), - implicit_correlate_froms=()) - else: - froms = select._get_display_froms( - explicit_correlate_froms=correlate_froms, - implicit_correlate_froms=asfrom_froms) - - new_correlate_froms = set(selectable._from_objects(*froms)) - all_correlate_froms = new_correlate_froms.union(correlate_froms) - - new_entry = { - 'asfrom_froms': new_correlate_froms, - 'iswrapper': iswrapper, - 'correlate_froms': all_correlate_froms, - 'selectable': select, - } - self.stack.append(new_entry) + froms = self._setup_select_stack(select, entry, asfrom, iswrapper) column_clause_args = kwargs.copy() column_clause_args.update({ @@ -1534,18 +1508,11 @@ class SQLCompiler(Compiled): text = "SELECT " # we're off to a good start ! if select._hints: - byfrom = dict([ - (from_, hinttext % { - 'name': from_._compiler_dispatch( - self, ashint=True) - }) - for (from_, dialect), hinttext in - select._hints.items() - if dialect in ('*', self.dialect.name) - ]) - hint_text = self.get_select_hint_text(byfrom) + hint_text, byfrom = self._setup_select_hints(select) if hint_text: text += hint_text + " " + else: + byfrom = None if select._prefixes: text += self._generate_prefixes( @@ -1566,6 +1533,70 @@ class SQLCompiler(Compiled): if c is not None ] + text = self._compose_select_body( + text, select, inner_columns, froms, byfrom, kwargs) + + if select._statement_hints: + per_dialect = [ + ht for (dialect_name, ht) + in select._statement_hints + if dialect_name in ('*', self.dialect.name) + ] + if per_dialect: + text += " " + self.get_statement_hint_text(per_dialect) + + if self.ctes and \ + compound_index == 0 and toplevel: + text = self._render_cte_clause() + text + + self.stack.pop(-1) + + if asfrom and parens: + return "(" + text + ")" + else: + return text + + def _setup_select_hints(self, select): + byfrom = dict([ + (from_, hinttext % { + 'name': from_._compiler_dispatch( + self, ashint=True) + }) + for (from_, dialect), hinttext in + select._hints.items() + if dialect in ('*', self.dialect.name) + ]) + hint_text = self.get_select_hint_text(byfrom) + return hint_text, byfrom + + def _setup_select_stack(self, select, entry, asfrom, iswrapper): + correlate_froms = entry['correlate_froms'] + asfrom_froms = entry['asfrom_froms'] + + if asfrom: + froms = select._get_display_froms( + explicit_correlate_froms=correlate_froms.difference( + asfrom_froms), + implicit_correlate_froms=()) + else: + froms = select._get_display_froms( + explicit_correlate_froms=correlate_froms, + implicit_correlate_froms=asfrom_froms) + + new_correlate_froms = set(selectable._from_objects(*froms)) + all_correlate_froms = new_correlate_froms.union(correlate_froms) + + new_entry = { + 'asfrom_froms': new_correlate_froms, + 'iswrapper': iswrapper, + 'correlate_froms': all_correlate_froms, + 'selectable': select, + } + self.stack.append(new_entry) + return froms + + def _compose_select_body( + self, text, select, inner_columns, froms, byfrom, kwargs): text += ', '.join(inner_columns) if froms: @@ -1609,16 +1640,7 @@ class SQLCompiler(Compiled): if select._for_update_arg is not None: text += self.for_update_clause(select, **kwargs) - if self.ctes and \ - compound_index == 0 and toplevel: - text = self._render_cte_clause() + text - - self.stack.pop(-1) - - if asfrom and parens: - return "(" + text + ")" - else: - return text + return text def _generate_prefixes(self, stmt, prefixes, **kw): clause = " ".join( @@ -1708,9 +1730,9 @@ class SQLCompiler(Compiled): def visit_insert(self, insert_stmt, **kw): self.isinsert = True - colparams = self._get_colparams(insert_stmt, **kw) + crud_params = crud._get_crud_params(self, insert_stmt, **kw) - if not colparams and \ + if not crud_params and \ not self.dialect.supports_default_values and \ not self.dialect.supports_empty_insert: raise exc.CompileError("The '%s' dialect with current database " @@ -1725,9 +1747,9 @@ class SQLCompiler(Compiled): "version settings does not support " "in-place multirow inserts." % self.dialect.name) - colparams_single = colparams[0] + crud_params_single = crud_params[0] else: - colparams_single = colparams + crud_params_single = crud_params preparer = self.preparer supports_default_values = self.dialect.supports_default_values @@ -1758,9 +1780,9 @@ class SQLCompiler(Compiled): text += table_text - if colparams_single or not supports_default_values: + if crud_params_single or not supports_default_values: text += " (%s)" % ', '.join([preparer.format_column(c[0]) - for c in colparams_single]) + for c in crud_params_single]) if self.returning or insert_stmt._returning: self.returning = self.returning or insert_stmt._returning @@ -1771,21 +1793,21 @@ class SQLCompiler(Compiled): text += " " + returning_clause if insert_stmt.select is not None: - text += " %s" % self.process(insert_stmt.select, **kw) - elif not colparams and supports_default_values: + text += " %s" % self.process(self._insert_from_select, **kw) + elif not crud_params and supports_default_values: text += " DEFAULT VALUES" elif insert_stmt._has_multi_parameters: text += " VALUES %s" % ( ", ".join( "(%s)" % ( - ', '.join(c[1] for c in colparam_set) + ', '.join(c[1] for c in crud_param_set) ) - for colparam_set in colparams + for crud_param_set in crud_params ) ) else: text += " VALUES (%s)" % \ - ', '.join([c[1] for c in colparams]) + ', '.join([c[1] for c in crud_params]) if self.returning and not self.returning_precedes_values: text += " " + returning_clause @@ -1842,7 +1864,7 @@ class SQLCompiler(Compiled): table_text = self.update_tables_clause(update_stmt, update_stmt.table, extra_froms, **kw) - colparams = self._get_colparams(update_stmt, **kw) + crud_params = crud._get_crud_params(self, update_stmt, **kw) if update_stmt._hints: dialect_hints = dict([ @@ -1869,7 +1891,7 @@ class SQLCompiler(Compiled): text += ', '.join( c[0]._compiler_dispatch(self, include_table=include_table) + - '=' + c[1] for c in colparams + '=' + c[1] for c in crud_params ) if self.returning or update_stmt._returning: @@ -1905,380 +1927,9 @@ class SQLCompiler(Compiled): return text - def _create_crud_bind_param(self, col, value, required=False, name=None): - if name is None: - name = col.key - bindparam = elements.BindParameter(name, value, - type_=col.type, required=required) - bindparam._is_crud = True - return bindparam._compiler_dispatch(self) - @util.memoized_property def _key_getters_for_crud_column(self): - if self.isupdate and self.statement._extra_froms: - # when extra tables are present, refer to the columns - # in those extra tables as table-qualified, including in - # dictionaries and when rendering bind param names. - # the "main" table of the statement remains unqualified, - # allowing the most compatibility with a non-multi-table - # statement. - _et = set(self.statement._extra_froms) - - def _column_as_key(key): - str_key = elements._column_as_key(key) - if hasattr(key, 'table') and key.table in _et: - return (key.table.name, str_key) - else: - return str_key - - def _getattr_col_key(col): - if col.table in _et: - return (col.table.name, col.key) - else: - return col.key - - def _col_bind_name(col): - if col.table in _et: - return "%s_%s" % (col.table.name, col.key) - else: - return col.key - - else: - _column_as_key = elements._column_as_key - _getattr_col_key = _col_bind_name = operator.attrgetter("key") - - return _column_as_key, _getattr_col_key, _col_bind_name - - def _get_colparams(self, stmt, **kw): - """create a set of tuples representing column/string pairs for use - in an INSERT or UPDATE statement. - - Also generates the Compiled object's postfetch, prefetch, and - returning column collections, used for default handling and ultimately - populating the ResultProxy's prefetch_cols() and postfetch_cols() - collections. - - """ - - self.postfetch = [] - self.prefetch = [] - self.returning = [] - - # no parameters in the statement, no parameters in the - # compiled params - return binds for all columns - if self.column_keys is None and stmt.parameters is None: - return [ - (c, self._create_crud_bind_param(c, - None, required=True)) - for c in stmt.table.columns - ] - - if stmt._has_multi_parameters: - stmt_parameters = stmt.parameters[0] - else: - stmt_parameters = stmt.parameters - - # getters - these are normally just column.key, - # but in the case of mysql multi-table update, the rules for - # .key must conditionally take tablename into account - _column_as_key, _getattr_col_key, _col_bind_name = \ - self._key_getters_for_crud_column - - # if we have statement parameters - set defaults in the - # compiled params - if self.column_keys is None: - parameters = {} - else: - parameters = dict((_column_as_key(key), REQUIRED) - for key in self.column_keys - if not stmt_parameters or - key not in stmt_parameters) - - # create a list of column assignment clauses as tuples - values = [] - - if stmt_parameters is not None: - for k, v in stmt_parameters.items(): - colkey = _column_as_key(k) - if colkey is not None: - parameters.setdefault(colkey, v) - else: - # a non-Column expression on the left side; - # add it to values() in an "as-is" state, - # coercing right side to bound param - if elements._is_literal(v): - v = self.process( - elements.BindParameter(None, v, type_=k.type), - **kw) - else: - v = self.process(v.self_group(), **kw) - - values.append((k, v)) - - need_pks = self.isinsert and \ - not self.inline and \ - not stmt._returning and \ - not stmt._has_multi_parameters - - implicit_returning = need_pks and \ - self.dialect.implicit_returning and \ - stmt.table.implicit_returning - - if self.isinsert: - implicit_return_defaults = (implicit_returning and - stmt._return_defaults) - elif self.isupdate: - implicit_return_defaults = (self.dialect.implicit_returning and - stmt.table.implicit_returning and - stmt._return_defaults) - else: - implicit_return_defaults = False - - if implicit_return_defaults: - if stmt._return_defaults is True: - implicit_return_defaults = set(stmt.table.c) - else: - implicit_return_defaults = set(stmt._return_defaults) - - postfetch_lastrowid = need_pks and self.dialect.postfetch_lastrowid - - check_columns = {} - - # special logic that only occurs for multi-table UPDATE - # statements - if self.isupdate and stmt._extra_froms and stmt_parameters: - normalized_params = dict( - (elements._clause_element_as_expr(c), param) - for c, param in stmt_parameters.items() - ) - affected_tables = set() - for t in stmt._extra_froms: - for c in t.c: - if c in normalized_params: - affected_tables.add(t) - check_columns[_getattr_col_key(c)] = c - value = normalized_params[c] - if elements._is_literal(value): - value = self._create_crud_bind_param( - c, value, required=value is REQUIRED, - name=_col_bind_name(c)) - else: - self.postfetch.append(c) - value = self.process(value.self_group(), **kw) - values.append((c, value)) - # determine tables which are actually - # to be updated - process onupdate and - # server_onupdate for these - for t in affected_tables: - for c in t.c: - if c in normalized_params: - continue - elif (c.onupdate is not None and not - c.onupdate.is_sequence): - if c.onupdate.is_clause_element: - values.append( - (c, self.process( - c.onupdate.arg.self_group(), - **kw) - ) - ) - self.postfetch.append(c) - else: - values.append( - (c, self._create_crud_bind_param( - c, None, name=_col_bind_name(c) - ) - ) - ) - self.prefetch.append(c) - elif c.server_onupdate is not None: - self.postfetch.append(c) - - if self.isinsert and stmt.select_names: - # for an insert from select, we can only use names that - # are given, so only select for those names. - cols = (stmt.table.c[_column_as_key(name)] - for name in stmt.select_names) - else: - # iterate through all table columns to maintain - # ordering, even for those cols that aren't included - cols = stmt.table.columns - - for c in cols: - col_key = _getattr_col_key(c) - if col_key in parameters and col_key not in check_columns: - value = parameters.pop(col_key) - if elements._is_literal(value): - value = self._create_crud_bind_param( - c, value, required=value is REQUIRED, - name=_col_bind_name(c) - if not stmt._has_multi_parameters - else "%s_0" % _col_bind_name(c) - ) - else: - if isinstance(value, elements.BindParameter) and \ - value.type._isnull: - value = value._clone() - value.type = c.type - - if c.primary_key and implicit_returning: - self.returning.append(c) - value = self.process(value.self_group(), **kw) - elif implicit_return_defaults and \ - c in implicit_return_defaults: - self.returning.append(c) - value = self.process(value.self_group(), **kw) - else: - self.postfetch.append(c) - value = self.process(value.self_group(), **kw) - values.append((c, value)) - - elif self.isinsert: - if c.primary_key and \ - need_pks and \ - ( - implicit_returning or - not postfetch_lastrowid or - c is not stmt.table._autoincrement_column - ): - - if implicit_returning: - if c.default is not None: - if c.default.is_sequence: - if self.dialect.supports_sequences and \ - (not c.default.optional or - not self.dialect.sequences_optional): - proc = self.process(c.default, **kw) - values.append((c, proc)) - self.returning.append(c) - elif c.default.is_clause_element: - values.append( - (c, self.process( - c.default.arg.self_group(), **kw)) - ) - self.returning.append(c) - else: - values.append( - (c, self._create_crud_bind_param(c, None)) - ) - self.prefetch.append(c) - else: - self.returning.append(c) - else: - if ( - (c.default is not None and - (not c.default.is_sequence or - self.dialect.supports_sequences)) or - c is stmt.table._autoincrement_column and - (self.dialect.supports_sequences or - self.dialect. - preexecute_autoincrement_sequences) - ): - - values.append( - (c, self._create_crud_bind_param(c, None)) - ) - - self.prefetch.append(c) - - elif c.default is not None: - if c.default.is_sequence: - if self.dialect.supports_sequences and \ - (not c.default.optional or - not self.dialect.sequences_optional): - proc = self.process(c.default, **kw) - values.append((c, proc)) - if implicit_return_defaults and \ - c in implicit_return_defaults: - self.returning.append(c) - elif not c.primary_key: - self.postfetch.append(c) - elif c.default.is_clause_element: - values.append( - (c, self.process( - c.default.arg.self_group(), **kw)) - ) - - if implicit_return_defaults and \ - c in implicit_return_defaults: - self.returning.append(c) - elif not c.primary_key: - # don't add primary key column to postfetch - self.postfetch.append(c) - else: - values.append( - (c, self._create_crud_bind_param(c, None)) - ) - self.prefetch.append(c) - elif c.server_default is not None: - if implicit_return_defaults and \ - c in implicit_return_defaults: - self.returning.append(c) - elif not c.primary_key: - self.postfetch.append(c) - elif implicit_return_defaults and \ - c in implicit_return_defaults: - self.returning.append(c) - - elif self.isupdate: - if c.onupdate is not None and not c.onupdate.is_sequence: - if c.onupdate.is_clause_element: - values.append( - (c, self.process( - c.onupdate.arg.self_group(), **kw)) - ) - if implicit_return_defaults and \ - c in implicit_return_defaults: - self.returning.append(c) - else: - self.postfetch.append(c) - else: - values.append( - (c, self._create_crud_bind_param(c, None)) - ) - self.prefetch.append(c) - elif c.server_onupdate is not None: - if implicit_return_defaults and \ - c in implicit_return_defaults: - self.returning.append(c) - else: - self.postfetch.append(c) - elif implicit_return_defaults and \ - c in implicit_return_defaults: - self.returning.append(c) - - if parameters and stmt_parameters: - check = set(parameters).intersection( - _column_as_key(k) for k in stmt.parameters - ).difference(check_columns) - if check: - raise exc.CompileError( - "Unconsumed column names: %s" % - (", ".join("%s" % c for c in check)) - ) - - if stmt._has_multi_parameters: - values_0 = values - values = [values] - - values.extend( - [ - ( - c, - (self._create_crud_bind_param( - c, row[c.key], - name="%s_%d" % (c.key, i + 1) - ) if elements._is_literal(row[c.key]) - else self.process( - row[c.key].self_group(), **kw)) - if c.key in row else param - ) - for (c, param) in values_0 - ] - for i, row in enumerate(stmt.parameters[1:]) - ) - - return values + return crud._key_getters_for_crud_column(self) def visit_delete(self, delete_stmt, **kw): self.stack.append({'correlate_froms': set([delete_stmt.table]), @@ -2462,17 +2113,18 @@ class DDLCompiler(Compiled): constraints.extend([c for c in table._sorted_constraints if c is not table.primary_key]) - return ", \n\t".join(p for p in - (self.process(constraint) - for constraint in constraints - if ( - constraint._create_rule is None or - constraint._create_rule(self)) - and ( - not self.dialect.supports_alter or - not getattr(constraint, 'use_alter', False) - )) if p is not None - ) + return ", \n\t".join( + p for p in + (self.process(constraint) + for constraint in constraints + if ( + constraint._create_rule is None or + constraint._create_rule(self)) + and ( + not self.dialect.supports_alter or + not getattr(constraint, 'use_alter', False) + )) if p is not None + ) def visit_drop_table(self, drop): return "\nDROP TABLE " + self.preparer.format_table(drop.element) diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py new file mode 100644 index 000000000..831d05be1 --- /dev/null +++ b/lib/sqlalchemy/sql/crud.py @@ -0,0 +1,530 @@ +# sql/crud.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Functions used by compiler.py to determine the parameters rendered +within INSERT and UPDATE statements. + +""" +from .. import util +from .. import exc +from . import elements +import operator + +REQUIRED = util.symbol('REQUIRED', """ +Placeholder for the value within a :class:`.BindParameter` +which is required to be present when the statement is passed +to :meth:`.Connection.execute`. + +This symbol is typically used when a :func:`.expression.insert` +or :func:`.expression.update` statement is compiled without parameter +values present. + +""") + + +def _get_crud_params(compiler, stmt, **kw): + """create a set of tuples representing column/string pairs for use + in an INSERT or UPDATE statement. + + Also generates the Compiled object's postfetch, prefetch, and + returning column collections, used for default handling and ultimately + populating the ResultProxy's prefetch_cols() and postfetch_cols() + collections. + + """ + + compiler.postfetch = [] + compiler.prefetch = [] + compiler.returning = [] + + # no parameters in the statement, no parameters in the + # compiled params - return binds for all columns + if compiler.column_keys is None and stmt.parameters is None: + return [ + (c, _create_bind_param( + compiler, c, None, required=True)) + for c in stmt.table.columns + ] + + if stmt._has_multi_parameters: + stmt_parameters = stmt.parameters[0] + else: + stmt_parameters = stmt.parameters + + # getters - these are normally just column.key, + # but in the case of mysql multi-table update, the rules for + # .key must conditionally take tablename into account + _column_as_key, _getattr_col_key, _col_bind_name = \ + _key_getters_for_crud_column(compiler) + + # if we have statement parameters - set defaults in the + # compiled params + if compiler.column_keys is None: + parameters = {} + else: + parameters = dict((_column_as_key(key), REQUIRED) + for key in compiler.column_keys + if not stmt_parameters or + key not in stmt_parameters) + + # create a list of column assignment clauses as tuples + values = [] + + if stmt_parameters is not None: + _get_stmt_parameters_params( + compiler, + parameters, stmt_parameters, _column_as_key, values, kw) + + check_columns = {} + + # special logic that only occurs for multi-table UPDATE + # statements + if compiler.isupdate and stmt._extra_froms and stmt_parameters: + _get_multitable_params( + compiler, stmt, stmt_parameters, check_columns, + _col_bind_name, _getattr_col_key, values, kw) + + if compiler.isinsert and stmt.select_names: + _scan_insert_from_select_cols( + compiler, stmt, parameters, + _getattr_col_key, _column_as_key, + _col_bind_name, check_columns, values, kw) + else: + _scan_cols( + compiler, stmt, parameters, + _getattr_col_key, _column_as_key, + _col_bind_name, check_columns, values, kw) + + if parameters and stmt_parameters: + check = set(parameters).intersection( + _column_as_key(k) for k in stmt.parameters + ).difference(check_columns) + if check: + raise exc.CompileError( + "Unconsumed column names: %s" % + (", ".join("%s" % c for c in check)) + ) + + if stmt._has_multi_parameters: + values = _extend_values_for_multiparams(compiler, stmt, values, kw) + + return values + + +def _create_bind_param( + compiler, col, value, process=True, required=False, name=None): + if name is None: + name = col.key + bindparam = elements.BindParameter(name, value, + type_=col.type, required=required) + bindparam._is_crud = True + if process: + bindparam = bindparam._compiler_dispatch(compiler) + return bindparam + + +def _key_getters_for_crud_column(compiler): + if compiler.isupdate and compiler.statement._extra_froms: + # when extra tables are present, refer to the columns + # in those extra tables as table-qualified, including in + # dictionaries and when rendering bind param names. + # the "main" table of the statement remains unqualified, + # allowing the most compatibility with a non-multi-table + # statement. + _et = set(compiler.statement._extra_froms) + + def _column_as_key(key): + str_key = elements._column_as_key(key) + if hasattr(key, 'table') and key.table in _et: + return (key.table.name, str_key) + else: + return str_key + + def _getattr_col_key(col): + if col.table in _et: + return (col.table.name, col.key) + else: + return col.key + + def _col_bind_name(col): + if col.table in _et: + return "%s_%s" % (col.table.name, col.key) + else: + return col.key + + else: + _column_as_key = elements._column_as_key + _getattr_col_key = _col_bind_name = operator.attrgetter("key") + + return _column_as_key, _getattr_col_key, _col_bind_name + + +def _scan_insert_from_select_cols( + compiler, stmt, parameters, _getattr_col_key, + _column_as_key, _col_bind_name, check_columns, values, kw): + + need_pks, implicit_returning, \ + implicit_return_defaults, postfetch_lastrowid = \ + _get_returning_modifiers(compiler, stmt) + + cols = [stmt.table.c[_column_as_key(name)] + for name in stmt.select_names] + + compiler._insert_from_select = stmt.select + + add_select_cols = [] + if stmt.include_insert_from_select_defaults: + col_set = set(cols) + for col in stmt.table.columns: + if col not in col_set and col.default: + cols.append(col) + + for c in cols: + col_key = _getattr_col_key(c) + if col_key in parameters and col_key not in check_columns: + parameters.pop(col_key) + values.append((c, None)) + else: + _append_param_insert_select_hasdefault( + compiler, stmt, c, add_select_cols, kw) + + if add_select_cols: + values.extend(add_select_cols) + compiler._insert_from_select = compiler._insert_from_select._generate() + compiler._insert_from_select._raw_columns += tuple( + expr for col, expr in add_select_cols) + + +def _scan_cols( + compiler, stmt, parameters, _getattr_col_key, + _column_as_key, _col_bind_name, check_columns, values, kw): + + need_pks, implicit_returning, \ + implicit_return_defaults, postfetch_lastrowid = \ + _get_returning_modifiers(compiler, stmt) + + cols = stmt.table.columns + + for c in cols: + col_key = _getattr_col_key(c) + if col_key in parameters and col_key not in check_columns: + + _append_param_parameter( + compiler, stmt, c, col_key, parameters, _col_bind_name, + implicit_returning, implicit_return_defaults, values, kw) + + elif compiler.isinsert: + if c.primary_key and \ + need_pks and \ + ( + implicit_returning or + not postfetch_lastrowid or + c is not stmt.table._autoincrement_column + ): + + if implicit_returning: + _append_param_insert_pk_returning( + compiler, stmt, c, values, kw) + else: + _append_param_insert_pk(compiler, stmt, c, values, kw) + + elif c.default is not None: + + _append_param_insert_hasdefault( + compiler, stmt, c, implicit_return_defaults, + values, kw) + + elif c.server_default is not None: + if implicit_return_defaults and \ + c in implicit_return_defaults: + compiler.returning.append(c) + elif not c.primary_key: + compiler.postfetch.append(c) + elif implicit_return_defaults and \ + c in implicit_return_defaults: + compiler.returning.append(c) + + elif compiler.isupdate: + _append_param_update( + compiler, stmt, c, implicit_return_defaults, values, kw) + + +def _append_param_parameter( + compiler, stmt, c, col_key, parameters, _col_bind_name, + implicit_returning, implicit_return_defaults, values, kw): + value = parameters.pop(col_key) + if elements._is_literal(value): + value = _create_bind_param( + compiler, c, value, required=value is REQUIRED, + name=_col_bind_name(c) + if not stmt._has_multi_parameters + else "%s_0" % _col_bind_name(c) + ) + else: + if isinstance(value, elements.BindParameter) and \ + value.type._isnull: + value = value._clone() + value.type = c.type + + if c.primary_key and implicit_returning: + compiler.returning.append(c) + value = compiler.process(value.self_group(), **kw) + elif implicit_return_defaults and \ + c in implicit_return_defaults: + compiler.returning.append(c) + value = compiler.process(value.self_group(), **kw) + else: + compiler.postfetch.append(c) + value = compiler.process(value.self_group(), **kw) + values.append((c, value)) + + +def _append_param_insert_pk_returning(compiler, stmt, c, values, kw): + if c.default is not None: + if c.default.is_sequence: + if compiler.dialect.supports_sequences and \ + (not c.default.optional or + not compiler.dialect.sequences_optional): + proc = compiler.process(c.default, **kw) + values.append((c, proc)) + compiler.returning.append(c) + elif c.default.is_clause_element: + values.append( + (c, compiler.process( + c.default.arg.self_group(), **kw)) + ) + compiler.returning.append(c) + else: + values.append( + (c, _create_bind_param(compiler, c, None)) + ) + compiler.prefetch.append(c) + else: + compiler.returning.append(c) + + +def _append_param_insert_pk(compiler, stmt, c, values, kw): + if ( + (c.default is not None and + (not c.default.is_sequence or + compiler.dialect.supports_sequences)) or + c is stmt.table._autoincrement_column and + (compiler.dialect.supports_sequences or + compiler.dialect. + preexecute_autoincrement_sequences) + ): + values.append( + (c, _create_bind_param(compiler, c, None)) + ) + + compiler.prefetch.append(c) + + +def _append_param_insert_hasdefault( + compiler, stmt, c, implicit_return_defaults, values, kw): + + if c.default.is_sequence: + if compiler.dialect.supports_sequences and \ + (not c.default.optional or + not compiler.dialect.sequences_optional): + proc = compiler.process(c.default, **kw) + values.append((c, proc)) + if implicit_return_defaults and \ + c in implicit_return_defaults: + compiler.returning.append(c) + elif not c.primary_key: + compiler.postfetch.append(c) + elif c.default.is_clause_element: + proc = compiler.process(c.default.arg.self_group(), **kw) + values.append((c, proc)) + + if implicit_return_defaults and \ + c in implicit_return_defaults: + compiler.returning.append(c) + elif not c.primary_key: + # don't add primary key column to postfetch + compiler.postfetch.append(c) + else: + values.append( + (c, _create_bind_param(compiler, c, None)) + ) + compiler.prefetch.append(c) + + +def _append_param_insert_select_hasdefault( + compiler, stmt, c, values, kw): + + if c.default.is_sequence: + if compiler.dialect.supports_sequences and \ + (not c.default.optional or + not compiler.dialect.sequences_optional): + proc = c.default + values.append((c, proc)) + elif c.default.is_clause_element: + proc = c.default.arg.self_group() + values.append((c, proc)) + else: + values.append( + (c, _create_bind_param(compiler, c, None, process=False)) + ) + compiler.prefetch.append(c) + + +def _append_param_update( + compiler, stmt, c, implicit_return_defaults, values, kw): + + if c.onupdate is not None and not c.onupdate.is_sequence: + if c.onupdate.is_clause_element: + values.append( + (c, compiler.process( + c.onupdate.arg.self_group(), **kw)) + ) + if implicit_return_defaults and \ + c in implicit_return_defaults: + compiler.returning.append(c) + else: + compiler.postfetch.append(c) + else: + values.append( + (c, _create_bind_param(compiler, c, None)) + ) + compiler.prefetch.append(c) + elif c.server_onupdate is not None: + if implicit_return_defaults and \ + c in implicit_return_defaults: + compiler.returning.append(c) + else: + compiler.postfetch.append(c) + elif implicit_return_defaults and \ + c in implicit_return_defaults: + compiler.returning.append(c) + + +def _get_multitable_params( + compiler, stmt, stmt_parameters, check_columns, + _col_bind_name, _getattr_col_key, values, kw): + + normalized_params = dict( + (elements._clause_element_as_expr(c), param) + for c, param in stmt_parameters.items() + ) + affected_tables = set() + for t in stmt._extra_froms: + for c in t.c: + if c in normalized_params: + affected_tables.add(t) + check_columns[_getattr_col_key(c)] = c + value = normalized_params[c] + if elements._is_literal(value): + value = _create_bind_param( + compiler, c, value, required=value is REQUIRED, + name=_col_bind_name(c)) + else: + compiler.postfetch.append(c) + value = compiler.process(value.self_group(), **kw) + values.append((c, value)) + # determine tables which are actually to be updated - process onupdate + # and server_onupdate for these + for t in affected_tables: + for c in t.c: + if c in normalized_params: + continue + elif (c.onupdate is not None and not + c.onupdate.is_sequence): + if c.onupdate.is_clause_element: + values.append( + (c, compiler.process( + c.onupdate.arg.self_group(), + **kw) + ) + ) + compiler.postfetch.append(c) + else: + values.append( + (c, _create_bind_param( + compiler, c, None, name=_col_bind_name(c) + ) + ) + ) + compiler.prefetch.append(c) + elif c.server_onupdate is not None: + compiler.postfetch.append(c) + + +def _extend_values_for_multiparams(compiler, stmt, values, kw): + values_0 = values + values = [values] + + values.extend( + [ + ( + c, + (_create_bind_param( + compiler, c, row[c.key], + name="%s_%d" % (c.key, i + 1) + ) if elements._is_literal(row[c.key]) + else compiler.process( + row[c.key].self_group(), **kw)) + if c.key in row else param + ) + for (c, param) in values_0 + ] + for i, row in enumerate(stmt.parameters[1:]) + ) + return values + + +def _get_stmt_parameters_params( + compiler, parameters, stmt_parameters, _column_as_key, values, kw): + for k, v in stmt_parameters.items(): + colkey = _column_as_key(k) + if colkey is not None: + parameters.setdefault(colkey, v) + else: + # a non-Column expression on the left side; + # add it to values() in an "as-is" state, + # coercing right side to bound param + if elements._is_literal(v): + v = compiler.process( + elements.BindParameter(None, v, type_=k.type), + **kw) + else: + v = compiler.process(v.self_group(), **kw) + + values.append((k, v)) + + +def _get_returning_modifiers(compiler, stmt): + need_pks = compiler.isinsert and \ + not compiler.inline and \ + not stmt._returning and \ + not stmt._has_multi_parameters + + implicit_returning = need_pks and \ + compiler.dialect.implicit_returning and \ + stmt.table.implicit_returning + + if compiler.isinsert: + implicit_return_defaults = (implicit_returning and + stmt._return_defaults) + elif compiler.isupdate: + implicit_return_defaults = (compiler.dialect.implicit_returning and + stmt.table.implicit_returning and + stmt._return_defaults) + else: + implicit_return_defaults = False + + if implicit_return_defaults: + if stmt._return_defaults is True: + implicit_return_defaults = set(stmt.table.c) + else: + implicit_return_defaults = set(stmt._return_defaults) + + postfetch_lastrowid = need_pks and compiler.dialect.postfetch_lastrowid + + return need_pks, implicit_returning, \ + implicit_return_defaults, postfetch_lastrowid diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index 1934d0776..9f2ce7ce3 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -475,6 +475,7 @@ class Insert(ValuesBase): ValuesBase.__init__(self, table, values, prefixes) self._bind = bind self.select = self.select_names = None + self.include_insert_from_select_defaults = False self.inline = inline self._returning = returning self._validate_dialect_kwargs(dialect_kw) @@ -487,7 +488,7 @@ class Insert(ValuesBase): return () @_generative - def from_select(self, names, select): + def from_select(self, names, select, include_defaults=True): """Return a new :class:`.Insert` construct which represents an ``INSERT...FROM SELECT`` statement. @@ -506,6 +507,21 @@ class Insert(ValuesBase): is not checked before passing along to the database, the database would normally raise an exception if these column lists don't correspond. + :param include_defaults: if True, non-server default values and + SQL expressions as specified on :class:`.Column` objects + (as documented in :ref:`metadata_defaults_toplevel`) not + otherwise specified in the list of names will be rendered + into the INSERT and SELECT statements, so that these values are also + included in the data to be inserted. + + .. note:: A Python-side default that uses a Python callable function + will only be invoked **once** for the whole statement, and **not + per row**. + + .. versionadded:: 1.0.0 - :meth:`.Insert.from_select` now renders + Python-side and SQL expression column defaults into the + SELECT statement for columns otherwise not included in the + list of column names. .. versionchanged:: 1.0.0 an INSERT that uses FROM SELECT implies that the :paramref:`.insert.inline` flag is set to @@ -514,13 +530,6 @@ class Insert(ValuesBase): deals with an arbitrary number of rows, so the :attr:`.ResultProxy.inserted_primary_key` accessor does not apply. - .. note:: - - A SELECT..INSERT construct in SQL has no VALUES clause. Therefore - :class:`.Column` objects which utilize Python-side defaults - (e.g. as described at :ref:`metadata_defaults_toplevel`) - will **not** take effect when using :meth:`.Insert.from_select`. - .. versionadded:: 0.8.3 """ @@ -533,6 +542,7 @@ class Insert(ValuesBase): self.select_names = names self.inline = True + self.include_insert_from_select_defaults = include_defaults self.select = _interpret_as_select(select) def _copy_internals(self, clone=_clone, **kw): diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 8ec0aa700..fa9b66024 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -228,6 +228,7 @@ class ClauseElement(Visitable): is_selectable = False is_clause_element = True + description = None _order_by_label_element = None _is_from_container = False @@ -540,7 +541,7 @@ class ClauseElement(Visitable): __nonzero__ = __bool__ def __repr__(self): - friendly = getattr(self, 'description', None) + friendly = self.description if friendly is None: return object.__repr__(self) else: @@ -860,6 +861,9 @@ class ColumnElement(operators.ColumnOperators, ClauseElement): expressions and function calls. """ + while self._is_clone_of is not None: + self = self._is_clone_of + return _anonymous_label( '%%(%d %s)s' % (id(self), getattr(self, 'name', 'anon')) ) @@ -1616,10 +1620,10 @@ class Null(ColumnElement): return type_api.NULLTYPE @classmethod - def _singleton(cls): + def _instance(cls): """Return a constant :class:`.Null` construct.""" - return NULL + return Null() def compare(self, other): return isinstance(other, Null) @@ -1640,11 +1644,11 @@ class False_(ColumnElement): return type_api.BOOLEANTYPE def _negate(self): - return TRUE + return True_() @classmethod - def _singleton(cls): - """Return a constant :class:`.False_` construct. + def _instance(cls): + """Return a :class:`.False_` construct. E.g.:: @@ -1678,7 +1682,7 @@ class False_(ColumnElement): """ - return FALSE + return False_() def compare(self, other): return isinstance(other, False_) @@ -1699,17 +1703,17 @@ class True_(ColumnElement): return type_api.BOOLEANTYPE def _negate(self): - return FALSE + return False_() @classmethod def _ifnone(cls, other): if other is None: - return cls._singleton() + return cls._instance() else: return other @classmethod - def _singleton(cls): + def _instance(cls): """Return a constant :class:`.True_` construct. E.g.:: @@ -1744,15 +1748,11 @@ class True_(ColumnElement): """ - return TRUE + return True_() def compare(self, other): return isinstance(other, True_) -NULL = Null() -FALSE = False_() -TRUE = True_() - class ClauseList(ClauseElement): """Describe a list of clauses, separated by an operator. @@ -2782,6 +2782,10 @@ class Grouping(ColumnElement): return self @property + def _key_label(self): + return self._label + + @property def _label(self): return getattr(self.element, '_label', None) or self.anon_label @@ -2888,6 +2892,120 @@ class Over(ColumnElement): )) +class FunctionFilter(ColumnElement): + """Represent a function FILTER clause. + + This is a special operator against aggregate and window functions, + which controls which rows are passed to it. + It's supported only by certain database backends. + + Invocation of :class:`.FunctionFilter` is via + :meth:`.FunctionElement.filter`:: + + func.count(1).filter(True) + + .. versionadded:: 1.0.0 + + .. seealso:: + + :meth:`.FunctionElement.filter` + + """ + __visit_name__ = 'funcfilter' + + criterion = None + + def __init__(self, func, *criterion): + """Produce a :class:`.FunctionFilter` object against a function. + + Used against aggregate and window functions, + for database backends that support the "FILTER" clause. + + E.g.:: + + from sqlalchemy import funcfilter + funcfilter(func.count(1), MyClass.name == 'some name') + + Would produce "COUNT(1) FILTER (WHERE myclass.name = 'some name')". + + This function is also available from the :data:`~.expression.func` + construct itself via the :meth:`.FunctionElement.filter` method. + + .. versionadded:: 1.0.0 + + .. seealso:: + + :meth:`.FunctionElement.filter` + + + """ + self.func = func + self.filter(*criterion) + + def filter(self, *criterion): + """Produce an additional FILTER against the function. + + This method adds additional criteria to the initial criteria + set up by :meth:`.FunctionElement.filter`. + + Multiple criteria are joined together at SQL render time + via ``AND``. + + + """ + + for criterion in list(criterion): + criterion = _expression_literal_as_text(criterion) + + if self.criterion is not None: + self.criterion = self.criterion & criterion + else: + self.criterion = criterion + + return self + + def over(self, partition_by=None, order_by=None): + """Produce an OVER clause against this filtered function. + + Used against aggregate or so-called "window" functions, + for database backends that support window functions. + + The expression:: + + func.rank().filter(MyClass.y > 5).over(order_by='x') + + is shorthand for:: + + from sqlalchemy import over, funcfilter + over(funcfilter(func.rank(), MyClass.y > 5), order_by='x') + + See :func:`~.expression.over` for a full description. + + """ + return Over(self, partition_by=partition_by, order_by=order_by) + + @util.memoized_property + def type(self): + return self.func.type + + def get_children(self, **kwargs): + return [c for c in + (self.func, self.criterion) + if c is not None] + + def _copy_internals(self, clone=_clone, **kw): + self.func = clone(self.func, **kw) + if self.criterion is not None: + self.criterion = clone(self.criterion, **kw) + + @property + def _from_objects(self): + return list(itertools.chain( + *[c._from_objects for c in (self.func, self.criterion) + if c is not None] + )) + + class Label(ColumnElement): """Represents a column label (AS). @@ -3491,7 +3609,7 @@ def _string_or_unprintable(element): else: try: return str(element) - except: + except Exception: return "unprintable element %r" % element diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index d96f048b9..2ffc5468c 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -36,7 +36,7 @@ from .elements import ClauseElement, ColumnElement,\ True_, False_, BinaryExpression, Tuple, TypeClause, Extract, \ Grouping, not_, \ collate, literal_column, between,\ - literal, outparam, type_coerce, ClauseList + literal, outparam, type_coerce, ClauseList, FunctionFilter from .elements import SavepointClause, RollbackToSavepointClause, \ ReleaseSavepointClause @@ -89,14 +89,16 @@ asc = public_factory(UnaryExpression._create_asc, ".expression.asc") desc = public_factory(UnaryExpression._create_desc, ".expression.desc") distinct = public_factory( UnaryExpression._create_distinct, ".expression.distinct") -true = public_factory(True_._singleton, ".expression.true") -false = public_factory(False_._singleton, ".expression.false") -null = public_factory(Null._singleton, ".expression.null") +true = public_factory(True_._instance, ".expression.true") +false = public_factory(False_._instance, ".expression.false") +null = public_factory(Null._instance, ".expression.null") join = public_factory(Join._create_join, ".expression.join") outerjoin = public_factory(Join._create_outerjoin, ".expression.outerjoin") insert = public_factory(Insert, ".expression.insert") update = public_factory(Update, ".expression.update") delete = public_factory(Delete, ".expression.delete") +funcfilter = public_factory( + FunctionFilter, ".expression.funcfilter") # internal functions still being called from tests and the ORM, diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index 7efb1e916..9280c7d60 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -12,7 +12,7 @@ from . import sqltypes, schema from .base import Executable, ColumnCollection from .elements import ClauseList, Cast, Extract, _literal_as_binds, \ literal_column, _type_from_args, ColumnElement, _clone,\ - Over, BindParameter + Over, BindParameter, FunctionFilter from .selectable import FromClause, Select, Alias from . import operators @@ -116,6 +116,35 @@ class FunctionElement(Executable, ColumnElement, FromClause): """ return Over(self, partition_by=partition_by, order_by=order_by) + def filter(self, *criterion): + """Produce a FILTER clause against this function. + + Used against aggregate and window functions, + for database backends that support the "FILTER" clause. + + The expression:: + + func.count(1).filter(True) + + is shorthand for:: + + from sqlalchemy import funcfilter + funcfilter(func.count(1), True) + + .. versionadded:: 1.0.0 + + .. seealso:: + + :class:`.FunctionFilter` + + :func:`.funcfilter` + + + """ + if not criterion: + return self + return FunctionFilter(self, *criterion) + @property def _from_objects(self): return self.clauses._from_objects diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index d9fd37f92..96cabbf4f 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -412,8 +412,8 @@ class Table(DialectKWArgs, SchemaItem, TableClause): table.dispatch.after_parent_attach(table, metadata) return table except: - metadata._remove_table(name, schema) - raise + with util.safe_reraise(): + metadata._remove_table(name, schema) @property @util.deprecated('0.9', 'Use ``table.schema.quote``') @@ -728,7 +728,7 @@ class Table(DialectKWArgs, SchemaItem, TableClause): checkfirst=checkfirst) def tometadata(self, metadata, schema=RETAIN_SCHEMA, - referred_schema_fn=None): + referred_schema_fn=None, name=None): """Return a copy of this :class:`.Table` associated with a different :class:`.MetaData`. @@ -785,13 +785,21 @@ class Table(DialectKWArgs, SchemaItem, TableClause): .. versionadded:: 0.9.2 - """ + :param name: optional string name indicating the target table name. + If not specified or None, the table name is retained. This allows + a :class:`.Table` to be copied to the same :class:`.MetaData` target + with a new name. + + .. versionadded:: 1.0.0 + """ + if name is None: + name = self.name if schema is RETAIN_SCHEMA: schema = self.schema elif schema is None: schema = metadata.schema - key = _get_table_key(self.name, schema) + key = _get_table_key(name, schema) if key in metadata.tables: util.warn("Table '%s' already exists within the given " "MetaData - not copying." % self.description) @@ -801,7 +809,7 @@ class Table(DialectKWArgs, SchemaItem, TableClause): for c in self.columns: args.append(c.copy(schema=schema)) table = Table( - self.name, metadata, schema=schema, + name, metadata, schema=schema, *args, **self.kwargs ) for c in self.constraints: @@ -1061,8 +1069,8 @@ class Column(SchemaItem, ColumnClause): conditionally rendered differently on different backends, consider custom compilation rules for :class:`.CreateColumn`. - ..versionadded:: 0.8.3 Added the ``system=True`` parameter to - :class:`.Column`. + .. versionadded:: 0.8.3 Added the ``system=True`` parameter to + :class:`.Column`. """ @@ -1222,8 +1230,10 @@ class Column(SchemaItem, ColumnClause): existing = getattr(self, 'table', None) if existing is not None and existing is not table: raise exc.ArgumentError( - "Column object already assigned to Table '%s'" % - existing.description) + "Column object '%s' already assigned to Table '%s'" % ( + self.key, + existing.description + )) if self.key in table._columns: col = table._columns.get(self.key) @@ -1547,7 +1557,7 @@ class ForeignKey(DialectKWArgs, SchemaItem): ) return self._schema_item_copy(fk) - def _get_colspec(self, schema=None): + def _get_colspec(self, schema=None, table_name=None): """Return a string based 'column specification' for this :class:`.ForeignKey`. @@ -1557,7 +1567,15 @@ class ForeignKey(DialectKWArgs, SchemaItem): """ if schema: _schema, tname, colname = self._column_tokens + if table_name is not None: + tname = table_name return "%s.%s.%s" % (schema, tname, colname) + elif table_name: + schema, tname, colname = self._column_tokens + if schema: + return "%s.%s.%s" % (schema, table_name, colname) + else: + return "%s.%s" % (table_name, colname) elif self._table_column is not None: return "%s.%s" % ( self._table_column.table.fullname, self._table_column.key) @@ -2649,10 +2667,15 @@ class ForeignKeyConstraint(Constraint): event.listen(table.metadata, "before_drop", ddl.DropConstraint(self, on=supports_alter)) - def copy(self, schema=None, **kw): + def copy(self, schema=None, target_table=None, **kw): fkc = ForeignKeyConstraint( [x.parent.key for x in self._elements.values()], - [x._get_colspec(schema=schema) + [x._get_colspec( + schema=schema, + table_name=target_table.name + if target_table is not None + and x._table_key() == x.parent.table.key + else None) for x in self._elements.values()], name=self.name, onupdate=self.onupdate, diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 9e8cb3bc5..8198a6733 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -746,6 +746,33 @@ class Join(FromClause): providing a "natural join". """ + constraints = cls._joincond_scan_left_right( + a, a_subset, b, consider_as_foreign_keys) + + if len(constraints) > 1: + cls._joincond_trim_constraints( + a, b, constraints, consider_as_foreign_keys) + + if len(constraints) == 0: + if isinstance(b, FromGrouping): + hint = " Perhaps you meant to convert the right side to a "\ + "subquery using alias()?" + else: + hint = "" + raise exc.NoForeignKeysError( + "Can't find any foreign key relationships " + "between '%s' and '%s'.%s" % + (a.description, b.description, hint)) + + crit = [(x == y) for x, y in list(constraints.values())[0]] + if len(crit) == 1: + return (crit[0]) + else: + return and_(*crit) + + @classmethod + def _joincond_scan_left_right( + cls, a, a_subset, b, consider_as_foreign_keys): constraints = collections.defaultdict(list) for left in (a_subset, a): @@ -780,57 +807,41 @@ class Join(FromClause): if nrte.table_name == b.name: raise else: - # this is totally covered. can't get - # coverage to mark it. continue if col is not None: constraints[fk.constraint].append((col, fk.parent)) if constraints: break + return constraints + @classmethod + def _joincond_trim_constraints( + cls, a, b, constraints, consider_as_foreign_keys): + # more than one constraint matched. narrow down the list + # to include just those FKCs that match exactly to + # "consider_as_foreign_keys". + if consider_as_foreign_keys: + for const in list(constraints): + if set(f.parent for f in const.elements) != set( + consider_as_foreign_keys): + del constraints[const] + + # if still multiple constraints, but + # they all refer to the exact same end result, use it. if len(constraints) > 1: - # more than one constraint matched. narrow down the list - # to include just those FKCs that match exactly to - # "consider_as_foreign_keys". - if consider_as_foreign_keys: - for const in list(constraints): - if set(f.parent for f in const.elements) != set( - consider_as_foreign_keys): - del constraints[const] - - # if still multiple constraints, but - # they all refer to the exact same end result, use it. - if len(constraints) > 1: - dedupe = set(tuple(crit) for crit in constraints.values()) - if len(dedupe) == 1: - key = list(constraints)[0] - constraints = {key: constraints[key]} - - if len(constraints) != 1: - raise exc.AmbiguousForeignKeysError( - "Can't determine join between '%s' and '%s'; " - "tables have more than one foreign key " - "constraint relationship between them. " - "Please specify the 'onclause' of this " - "join explicitly." % (a.description, b.description)) - - if len(constraints) == 0: - if isinstance(b, FromGrouping): - hint = " Perhaps you meant to convert the right side to a "\ - "subquery using alias()?" - else: - hint = "" - raise exc.NoForeignKeysError( - "Can't find any foreign key relationships " - "between '%s' and '%s'.%s" % - (a.description, b.description, hint)) - - crit = [(x == y) for x, y in list(constraints.values())[0]] - if len(crit) == 1: - return (crit[0]) - else: - return and_(*crit) + dedupe = set(tuple(crit) for crit in constraints.values()) + if len(dedupe) == 1: + key = list(constraints)[0] + constraints = {key: constraints[key]} + + if len(constraints) != 1: + raise exc.AmbiguousForeignKeysError( + "Can't determine join between '%s' and '%s'; " + "tables have more than one foreign key " + "constraint relationship between them. " + "Please specify the 'onclause' of this " + "join explicitly." % (a.description, b.description)) def select(self, whereclause=None, **kwargs): """Create a :class:`.Select` from this :class:`.Join`. @@ -2153,6 +2164,7 @@ class Select(HasPrefixes, GenerativeSelect): _prefixes = () _hints = util.immutabledict() + _statement_hints = () _distinct = False _from_cloned = None _correlate = () @@ -2525,10 +2537,30 @@ class Select(HasPrefixes, GenerativeSelect): return self._get_display_froms() + def with_statement_hint(self, text, dialect_name='*'): + """add a statement hint to this :class:`.Select`. + + This method is similar to :meth:`.Select.with_hint` except that + it does not require an individual table, and instead applies to the + statement as a whole. + + Hints here are specific to the backend database and may include + directives such as isolation levels, file directives, fetch directives, + etc. + + .. versionadded:: 1.0.0 + + .. seealso:: + + :meth:`.Select.with_hint` + + """ + return self.with_hint(None, text, dialect_name) + @_generative def with_hint(self, selectable, text, dialect_name='*'): - """Add an indexing hint for the given selectable to this - :class:`.Select`. + """Add an indexing or other executional context hint for the given + selectable to this :class:`.Select`. The text of the hint is rendered in the appropriate location for the database backend in use, relative @@ -2540,7 +2572,7 @@ class Select(HasPrefixes, GenerativeSelect): following:: select([mytable]).\\ - with_hint(mytable, "+ index(%(name)s ix_mytable)") + with_hint(mytable, "index(%(name)s ix_mytable)") Would render SQL as:: @@ -2551,13 +2583,19 @@ class Select(HasPrefixes, GenerativeSelect): and Sybase simultaneously:: select([mytable]).\\ - with_hint( - mytable, "+ index(%(name)s ix_mytable)", 'oracle').\\ + with_hint(mytable, "index(%(name)s ix_mytable)", 'oracle').\\ with_hint(mytable, "WITH INDEX ix_mytable", 'sybase') + .. seealso:: + + :meth:`.Select.with_statement_hint` + """ - self._hints = self._hints.union( - {(selectable, dialect_name): text}) + if selectable is None: + self._statement_hints += ((dialect_name, text), ) + else: + self._hints = self._hints.union( + {(selectable, dialect_name): text}) @property def type(self): diff --git a/lib/sqlalchemy/testing/engines.py b/lib/sqlalchemy/testing/engines.py index 67c13231e..1284f9c2a 100644 --- a/lib/sqlalchemy/testing/engines.py +++ b/lib/sqlalchemy/testing/engines.py @@ -37,8 +37,6 @@ class ConnectionKiller(object): def _safe(self, fn): try: fn() - except (SystemExit, KeyboardInterrupt): - raise except Exception as e: warnings.warn( "testing_reaper couldn't " @@ -168,8 +166,6 @@ class ReconnectFixture(object): def _safe(self, fn): try: fn() - except (SystemExit, KeyboardInterrupt): - raise except Exception as e: warnings.warn( "ReconnectFixture couldn't " diff --git a/lib/sqlalchemy/testing/exclusions.py b/lib/sqlalchemy/testing/exclusions.py index 283d89e36..f94724608 100644 --- a/lib/sqlalchemy/testing/exclusions.py +++ b/lib/sqlalchemy/testing/exclusions.py @@ -133,7 +133,7 @@ class compound(object): name, fail._as_string(config), str(ex)))) break else: - raise ex + util.raise_from_cause(ex) def _expect_success(self, config, name='block'): if not self.fails: @@ -178,8 +178,7 @@ class Predicate(object): @classmethod def as_predicate(cls, predicate, description=None): if isinstance(predicate, compound): - return cls.as_predicate(predicate.fails.union(predicate.skips)) - + return cls.as_predicate(predicate.enabled_for_config, description) elif isinstance(predicate, Predicate): if description and predicate.description is None: predicate.description = description diff --git a/lib/sqlalchemy/testing/provision.py b/lib/sqlalchemy/testing/provision.py index 0bcdad959..c8f7fdf30 100644 --- a/lib/sqlalchemy/testing/provision.py +++ b/lib/sqlalchemy/testing/provision.py @@ -120,7 +120,7 @@ def _pg_create_db(cfg, eng, ident): isolation_level="AUTOCOMMIT") as conn: try: _pg_drop_db(cfg, conn, ident) - except: + except Exception: pass currentdb = conn.scalar("select current_database()") conn.execute("CREATE DATABASE %s TEMPLATE %s" % (ident, currentdb)) @@ -131,7 +131,7 @@ def _mysql_create_db(cfg, eng, ident): with eng.connect() as conn: try: _mysql_drop_db(cfg, conn, ident) - except: + except Exception: pass conn.execute("CREATE DATABASE %s" % ident) conn.execute("CREATE DATABASE %s_test_schema" % ident) @@ -173,15 +173,15 @@ def _mysql_drop_db(cfg, eng, ident): with eng.connect() as conn: try: conn.execute("DROP DATABASE %s_test_schema" % ident) - except: + except Exception: pass try: conn.execute("DROP DATABASE %s_test_schema_2" % ident) - except: + except Exception: pass try: conn.execute("DROP DATABASE %s" % ident) - except: + except Exception: pass diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py index a04bcbbdd..da3e3128a 100644 --- a/lib/sqlalchemy/testing/requirements.py +++ b/lib/sqlalchemy/testing/requirements.py @@ -314,6 +314,20 @@ class SuiteRequirements(Requirements): return exclusions.open() @property + def temp_table_reflection(self): + return exclusions.open() + + @property + def temp_table_names(self): + """target dialect supports listing of temporary table names""" + return exclusions.closed() + + @property + def temporary_views(self): + """target database supports temporary views""" + return exclusions.closed() + + @property def index_reflection(self): return exclusions.open() diff --git a/lib/sqlalchemy/testing/suite/test_insert.py b/lib/sqlalchemy/testing/suite/test_insert.py index 92d3d93e5..38519dfb9 100644 --- a/lib/sqlalchemy/testing/suite/test_insert.py +++ b/lib/sqlalchemy/testing/suite/test_insert.py @@ -4,7 +4,7 @@ from .. import exclusions from ..assertions import eq_ from .. import engines -from sqlalchemy import Integer, String, select, util +from sqlalchemy import Integer, String, select, literal_column, literal from ..schema import Table, Column @@ -90,6 +90,13 @@ class InsertBehaviorTest(fixtures.TablesTest): Column('id', Integer, primary_key=True, autoincrement=False), Column('data', String(50)) ) + Table('includes_defaults', metadata, + Column('id', Integer, primary_key=True, + test_needs_autoincrement=True), + Column('data', String(50)), + Column('x', Integer, default=5), + Column('y', Integer, + default=literal_column("2", type_=Integer) + literal(2))) def test_autoclose_on_insert(self): if requirements.returning.enabled: @@ -158,6 +165,34 @@ class InsertBehaviorTest(fixtures.TablesTest): ("data3", ), ("data3", )] ) + @requirements.insert_from_select + def test_insert_from_select_with_defaults(self): + table = self.tables.includes_defaults + config.db.execute( + table.insert(), + [ + dict(id=1, data="data1"), + dict(id=2, data="data2"), + dict(id=3, data="data3"), + ] + ) + + config.db.execute( + table.insert(inline=True). + from_select(("id", "data",), + select([table.c.id + 5, table.c.data]). + where(table.c.data.in_(["data2", "data3"])) + ), + ) + + eq_( + config.db.execute( + select([table]).order_by(table.c.data, table.c.id) + ).fetchall(), + [(1, 'data1', 5, 4), (2, 'data2', 5, 4), + (7, 'data2', 5, 4), (3, 'data3', 5, 4), (8, 'data3', 5, 4)] + ) + class ReturningTest(fixtures.TablesTest): run_create_tables = 'each' diff --git a/lib/sqlalchemy/testing/suite/test_reflection.py b/lib/sqlalchemy/testing/suite/test_reflection.py index 575a38db9..08b858b47 100644 --- a/lib/sqlalchemy/testing/suite/test_reflection.py +++ b/lib/sqlalchemy/testing/suite/test_reflection.py @@ -95,6 +95,39 @@ class ComponentReflectionTest(fixtures.TablesTest): cls.define_index(metadata, users) if testing.requires.view_column_reflection.enabled: cls.define_views(metadata, schema) + if not schema and testing.requires.temp_table_reflection.enabled: + cls.define_temp_tables(metadata) + + @classmethod + def define_temp_tables(cls, metadata): + # cheat a bit, we should fix this with some dialect-level + # temp table fixture + if testing.against("oracle"): + kw = { + 'prefixes': ["GLOBAL TEMPORARY"], + 'oracle_on_commit': 'PRESERVE ROWS' + } + else: + kw = { + 'prefixes': ["TEMPORARY"], + } + + user_tmp = Table( + "user_tmp", metadata, + Column("id", sa.INT, primary_key=True), + Column('name', sa.VARCHAR(50)), + Column('foo', sa.INT), + sa.UniqueConstraint('name', name='user_tmp_uq'), + sa.Index("user_tmp_ix", "foo"), + **kw + ) + if testing.requires.view_reflection.enabled and \ + testing.requires.temporary_views.enabled: + event.listen( + user_tmp, "after_create", + DDL("create temporary view user_tmp_v as " + "select * from user_tmp") + ) @classmethod def define_index(cls, metadata, users): @@ -147,6 +180,7 @@ class ComponentReflectionTest(fixtures.TablesTest): users, addresses, dingalings = self.tables.users, \ self.tables.email_addresses, self.tables.dingalings insp = inspect(meta.bind) + if table_type == 'view': table_names = insp.get_view_names(schema) table_names.sort() @@ -162,6 +196,20 @@ class ComponentReflectionTest(fixtures.TablesTest): answer = ['dingalings', 'email_addresses', 'users'] eq_(sorted(table_names), answer) + @testing.requires.temp_table_names + def test_get_temp_table_names(self): + insp = inspect(testing.db) + temp_table_names = insp.get_temp_table_names() + eq_(sorted(temp_table_names), ['user_tmp']) + + @testing.requires.view_reflection + @testing.requires.temp_table_names + @testing.requires.temporary_views + def test_get_temp_view_names(self): + insp = inspect(self.metadata.bind) + temp_table_names = insp.get_temp_view_names() + eq_(sorted(temp_table_names), ['user_tmp_v']) + @testing.requires.table_reflection def test_get_table_names(self): self._test_get_table_names() @@ -294,6 +342,28 @@ class ComponentReflectionTest(fixtures.TablesTest): def test_get_columns_with_schema(self): self._test_get_columns(schema=testing.config.test_schema) + @testing.requires.temp_table_reflection + def test_get_temp_table_columns(self): + meta = MetaData(testing.db) + user_tmp = self.tables.user_tmp + insp = inspect(meta.bind) + cols = insp.get_columns('user_tmp') + self.assert_(len(cols) > 0, len(cols)) + + for i, col in enumerate(user_tmp.columns): + eq_(col.name, cols[i]['name']) + + @testing.requires.temp_table_reflection + @testing.requires.view_column_reflection + @testing.requires.temporary_views + def test_get_temp_view_columns(self): + insp = inspect(self.metadata.bind) + cols = insp.get_columns('user_tmp_v') + eq_( + [col['name'] for col in cols], + ['id', 'name', 'foo'] + ) + @testing.requires.view_column_reflection def test_get_view_columns(self): self._test_get_columns(table_type='view') @@ -426,6 +496,28 @@ class ComponentReflectionTest(fixtures.TablesTest): def test_get_unique_constraints(self): self._test_get_unique_constraints() + @testing.requires.temp_table_reflection + @testing.requires.unique_constraint_reflection + def test_get_temp_table_unique_constraints(self): + insp = inspect(self.metadata.bind) + reflected = insp.get_unique_constraints('user_tmp') + for refl in reflected: + # Different dialects handle duplicate index and constraints + # differently, so ignore this flag + refl.pop('duplicates_index', None) + eq_(reflected, [{'column_names': ['name'], 'name': 'user_tmp_uq'}]) + + @testing.requires.temp_table_reflection + def test_get_temp_table_indexes(self): + insp = inspect(self.metadata.bind) + indexes = insp.get_indexes('user_tmp') + eq_( + # TODO: we need to add better filtering for indexes/uq constraints + # that are doubled up + [idx for idx in indexes if idx['name'] == 'user_tmp_ix'], + [{'unique': False, 'column_names': ['foo'], 'name': 'user_tmp_ix'}] + ) + @testing.requires.unique_constraint_reflection @testing.requires.schemas def test_get_unique_constraints_with_schema(self): @@ -466,6 +558,9 @@ class ComponentReflectionTest(fixtures.TablesTest): ) for orig, refl in zip(uniques, reflected): + # Different dialects handle duplicate index and constraints + # differently, so ignore this flag + refl.pop('duplicates_index', None) eq_(orig, refl) @testing.provide_metadata diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py index c963b18c3..dfed5b90a 100644 --- a/lib/sqlalchemy/util/__init__.py +++ b/lib/sqlalchemy/util/__init__.py @@ -33,7 +33,8 @@ from .langhelpers import iterate_attributes, class_hierarchy, \ duck_type_collection, assert_arg_type, symbol, dictlike_iteritems,\ classproperty, set_creation_order, warn_exception, warn, NoneType,\ constructor_copy, methods_equivalent, chop_traceback, asint,\ - generic_repr, counter, PluginLoader, hybridmethod, safe_reraise,\ + generic_repr, counter, PluginLoader, hybridproperty, hybridmethod, \ + safe_reraise,\ get_callable_argspec, only_once, attrsetter, ellipses_string, \ warn_limited diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py index 76f85f605..5c17bea88 100644 --- a/lib/sqlalchemy/util/langhelpers.py +++ b/lib/sqlalchemy/util/langhelpers.py @@ -134,7 +134,8 @@ def public_factory(target, location): fn = target.__init__ callable_ = target doc = "Construct a new :class:`.%s` object. \n\n"\ - "This constructor is mirrored as a public API function; see :func:`~%s` "\ + "This constructor is mirrored as a public API function; "\ + "see :func:`~%s` "\ "for a full usage and argument description." % ( target.__name__, location, ) else: @@ -155,6 +156,7 @@ def %(name)s(%(args)s): exec(code, env) decorated = env[location_name] decorated.__doc__ = fn.__doc__ + decorated.__module__ = "sqlalchemy" + location.rsplit(".", 1)[0] if compat.py2k or hasattr(fn, '__func__'): fn.__func__.__doc__ = doc else: @@ -490,7 +492,7 @@ def generic_repr(obj, additional_kw=(), to_inspect=None, omit_kwarg=()): val = getattr(obj, arg, missing) if val is not missing and val != defval: output.append('%s=%r' % (arg, val)) - except: + except Exception: pass if additional_kw: @@ -499,7 +501,7 @@ def generic_repr(obj, additional_kw=(), to_inspect=None, omit_kwarg=()): val = getattr(obj, arg, missing) if val is not missing and val != defval: output.append('%s=%r' % (arg, val)) - except: + except Exception: pass return "%s(%s)" % (obj.__class__.__name__, ", ".join(output)) @@ -1090,10 +1092,23 @@ class classproperty(property): return desc.fget(cls) +class hybridproperty(object): + def __init__(self, func): + self.func = func + + def __get__(self, instance, owner): + if instance is None: + clsval = self.func(owner) + clsval.__doc__ = self.func.__doc__ + return clsval + else: + return self.func(instance) + + class hybridmethod(object): """Decorate a function as cls- or instance- level.""" - def __init__(self, func, expr=None): + def __init__(self, func): self.func = func def __get__(self, instance, owner): @@ -1185,7 +1200,7 @@ def warn_exception(func, *args, **kwargs): """ try: return func(*args, **kwargs) - except: + except Exception: warn("%s('%s') ignored" % sys.exc_info()[0:2]) |
