Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--  lib/sqlalchemy/__init__.py  1
-rw-r--r--  lib/sqlalchemy/dialects/mssql/base.py  2
-rw-r--r--  lib/sqlalchemy/dialects/mssql/pymssql.py  2
-rw-r--r--  lib/sqlalchemy/dialects/mysql/base.py  30
-rw-r--r--  lib/sqlalchemy/dialects/mysql/mysqlconnector.py  48
-rw-r--r--  lib/sqlalchemy/dialects/oracle/base.py  49
-rw-r--r--  lib/sqlalchemy/dialects/postgresql/base.py  100
-rw-r--r--  lib/sqlalchemy/dialects/postgresql/psycopg2.py  80
-rw-r--r--  lib/sqlalchemy/dialects/sqlite/__init__.py  2
-rw-r--r--  lib/sqlalchemy/dialects/sqlite/base.py  85
-rw-r--r--  lib/sqlalchemy/dialects/sqlite/pysqlcipher.py  116
-rw-r--r--  lib/sqlalchemy/engine/base.py  159
-rw-r--r--  lib/sqlalchemy/engine/interfaces.py  18
-rw-r--r--  lib/sqlalchemy/engine/reflection.py  184
-rw-r--r--  lib/sqlalchemy/engine/strategies.py  1
-rw-r--r--  lib/sqlalchemy/event/attr.py  6
-rw-r--r--  lib/sqlalchemy/event/registry.py  40
-rw-r--r--  lib/sqlalchemy/events.py  5
-rw-r--r--  lib/sqlalchemy/exc.py  27
-rw-r--r--  lib/sqlalchemy/ext/automap.py  46
-rw-r--r--  lib/sqlalchemy/ext/declarative/__init__.py  117
-rw-r--r--  lib/sqlalchemy/ext/declarative/api.py  185
-rw-r--r--  lib/sqlalchemy/ext/declarative/base.py  697
-rw-r--r--  lib/sqlalchemy/ext/declarative/clsregistry.py  7
-rw-r--r--  lib/sqlalchemy/orm/collections.py  48
-rw-r--r--  lib/sqlalchemy/orm/events.py  12
-rw-r--r--  lib/sqlalchemy/orm/mapper.py  21
-rw-r--r--  lib/sqlalchemy/orm/persistence.py  47
-rw-r--r--  lib/sqlalchemy/orm/query.py  81
-rw-r--r--  lib/sqlalchemy/orm/relationships.py  56
-rw-r--r--  lib/sqlalchemy/orm/session.py  108
-rw-r--r--  lib/sqlalchemy/orm/state.py  15
-rw-r--r--  lib/sqlalchemy/orm/util.py  29
-rw-r--r--  lib/sqlalchemy/pool.py  20
-rw-r--r--  lib/sqlalchemy/sql/__init__.py  1
-rw-r--r--  lib/sqlalchemy/sql/compiler.py  584
-rw-r--r--  lib/sqlalchemy/sql/crud.py  530
-rw-r--r--  lib/sqlalchemy/sql/dml.py  26
-rw-r--r--  lib/sqlalchemy/sql/elements.py  150
-rw-r--r--  lib/sqlalchemy/sql/expression.py  10
-rw-r--r--  lib/sqlalchemy/sql/functions.py  31
-rw-r--r--  lib/sqlalchemy/sql/schema.py  49
-rw-r--r--  lib/sqlalchemy/sql/selectable.py  138
-rw-r--r--  lib/sqlalchemy/testing/engines.py  4
-rw-r--r--  lib/sqlalchemy/testing/exclusions.py  5
-rw-r--r--  lib/sqlalchemy/testing/provision.py  10
-rw-r--r--  lib/sqlalchemy/testing/requirements.py  14
-rw-r--r--  lib/sqlalchemy/testing/suite/test_insert.py  37
-rw-r--r--  lib/sqlalchemy/testing/suite/test_reflection.py  95
-rw-r--r--  lib/sqlalchemy/util/__init__.py  3
-rw-r--r--  lib/sqlalchemy/util/langhelpers.py  25
51 files changed, 2912 insertions(+), 1244 deletions(-)
diff --git a/lib/sqlalchemy/__init__.py b/lib/sqlalchemy/__init__.py
index 853566172..d184e1fbf 100644
--- a/lib/sqlalchemy/__init__.py
+++ b/lib/sqlalchemy/__init__.py
@@ -25,6 +25,7 @@ from .sql import (
extract,
false,
func,
+ funcfilter,
insert,
intersect,
intersect_all,
diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py
index ba3050ae5..dad02ee0f 100644
--- a/lib/sqlalchemy/dialects/mssql/base.py
+++ b/lib/sqlalchemy/dialects/mssql/base.py
@@ -846,7 +846,7 @@ class MSExecutionContext(default.DefaultExecutionContext):
"SET IDENTITY_INSERT %s OFF" %
self.dialect.identifier_preparer.format_table(
self.compiled.statement.table)))
- except:
+ except Exception:
pass
def get_result_proxy(self):
diff --git a/lib/sqlalchemy/dialects/mssql/pymssql.py b/lib/sqlalchemy/dialects/mssql/pymssql.py
index 8f76336ae..b5a1bc566 100644
--- a/lib/sqlalchemy/dialects/mssql/pymssql.py
+++ b/lib/sqlalchemy/dialects/mssql/pymssql.py
@@ -63,7 +63,7 @@ class MSDialect_pymssql(MSDialect):
def _get_server_version_info(self, connection):
vers = connection.scalar("select @@version")
m = re.match(
- r"Microsoft SQL Server.*? - (\d+).(\d+).(\d+).(\d+)", vers)
+ r"Microsoft .*? - (\d+).(\d+).(\d+).(\d+)", vers)
if m:
return tuple(int(x) for x in m.group(1, 2, 3, 4))
else:
diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py
index 7ccd59abb..2fb054d0c 100644
--- a/lib/sqlalchemy/dialects/mysql/base.py
+++ b/lib/sqlalchemy/dialects/mysql/base.py
@@ -341,6 +341,29 @@ reflection will not include foreign keys. For these tables, you may supply a
:ref:`mysql_storage_engines`
+.. _mysql_unique_constraints:
+
+MySQL Unique Constraints and Reflection
+---------------------------------------
+
+SQLAlchemy supports both the :class:`.Index` construct with the
+flag ``unique=True``, indicating a UNIQUE index, as well as the
+:class:`.UniqueConstraint` construct, representing a UNIQUE constraint.
+Both objects/syntaxes are supported by MySQL when emitting DDL to create
+these constraints. However, MySQL does not have a unique constraint
+construct that is separate from a unique index; that is, the "UNIQUE"
+constraint on MySQL is equivalent to creating a "UNIQUE INDEX".
+
+When reflecting these constructs, the :meth:`.Inspector.get_indexes`
+and the :meth:`.Inspector.get_unique_constraints` methods will **both**
+return an entry for a UNIQUE index in MySQL. However, when performing
+full table reflection using ``Table(..., autoload=True)``,
+the :class:`.UniqueConstraint` construct is
+**not** part of the fully reflected :class:`.Table` construct under any
+circumstances; this construct is always represented by a :class:`.Index`
+with the ``unique=True`` setting present in the :attr:`.Table.indexes`
+collection.
+
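+For example, given a table with a UNIQUE index, a rough sketch of both
+inspection methods (the engine URL and table name are illustrative)::
+
+    from sqlalchemy import create_engine, inspect
+
+    insp = inspect(create_engine("mysql://scott:tiger@localhost/test"))
+
+    insp.get_indexes("some_table")             # reports the UNIQUE index
+    insp.get_unique_constraints("some_table")  # reports it as well, with
+                                               # a "duplicates_index" key
+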
.. _mysql_timestamp_null:
@@ -2317,7 +2340,7 @@ class MySQLDialect(default.DefaultDialect):
# basic operations via autocommit fail.
try:
dbapi_connection.commit()
- except:
+ except Exception:
if self.server_version_info < (3, 23, 15):
args = sys.exc_info()[1].args
if args and args[0] == 1064:
@@ -2329,7 +2352,7 @@ class MySQLDialect(default.DefaultDialect):
try:
dbapi_connection.rollback()
- except:
+ except Exception:
if self.server_version_info < (3, 23, 15):
args = sys.exc_info()[1].args
if args and args[0] == 1064:
@@ -2590,7 +2613,8 @@ class MySQLDialect(default.DefaultDialect):
return [
{
'name': key['name'],
- 'column_names': [col[0] for col in key['columns']]
+ 'column_names': [col[0] for col in key['columns']],
+ 'duplicates_index': key['name'],
}
for key in parsed_state.keys
if key['type'] == 'UNIQUE'
diff --git a/lib/sqlalchemy/dialects/mysql/mysqlconnector.py b/lib/sqlalchemy/dialects/mysql/mysqlconnector.py
index e51e80005..417e1ad6f 100644
--- a/lib/sqlalchemy/dialects/mysql/mysqlconnector.py
+++ b/lib/sqlalchemy/dialects/mysql/mysqlconnector.py
@@ -21,6 +21,7 @@ from .base import (MySQLDialect, MySQLExecutionContext,
BIT)
from ... import util
+import re
class MySQLExecutionContext_mysqlconnector(MySQLExecutionContext):
@@ -31,18 +32,34 @@ class MySQLExecutionContext_mysqlconnector(MySQLExecutionContext):
class MySQLCompiler_mysqlconnector(MySQLCompiler):
def visit_mod_binary(self, binary, operator, **kw):
- return self.process(binary.left, **kw) + " %% " + \
- self.process(binary.right, **kw)
+ if self.dialect._mysqlconnector_double_percents:
+ return self.process(binary.left, **kw) + " %% " + \
+ self.process(binary.right, **kw)
+ else:
+ return self.process(binary.left, **kw) + " % " + \
+ self.process(binary.right, **kw)
def post_process_text(self, text):
- return text.replace('%', '%%')
+ if self.dialect._mysqlconnector_double_percents:
+ return text.replace('%', '%%')
+ else:
+ return text
+
+ def escape_literal_column(self, text):
+ if self.dialect._mysqlconnector_double_percents:
+ return text.replace('%', '%%')
+ else:
+ return text
class MySQLIdentifierPreparer_mysqlconnector(MySQLIdentifierPreparer):
def _escape_identifier(self, value):
value = value.replace(self.escape_quote, self.escape_to_quote)
- return value.replace("%", "%%")
+ if self.dialect._mysqlconnector_double_percents:
+ return value.replace("%", "%%")
+ else:
+ return value
class _myconnpyBIT(BIT):
@@ -55,8 +72,6 @@ class _myconnpyBIT(BIT):
class MySQLDialect_mysqlconnector(MySQLDialect):
driver = 'mysqlconnector'
- if util.py2k:
- supports_unicode_statements = False
supports_unicode_binds = True
supports_sane_rowcount = True
@@ -77,6 +92,10 @@ class MySQLDialect_mysqlconnector(MySQLDialect):
}
)
+ @util.memoized_property
+ def supports_unicode_statements(self):
+ return util.py3k or self._mysqlconnector_version_info > (2, 0)
+
@classmethod
def dbapi(cls):
from mysql import connector
@@ -103,10 +122,25 @@ class MySQLDialect_mysqlconnector(MySQLDialect):
'client_flags', ClientFlag.get_default())
client_flags |= ClientFlag.FOUND_ROWS
opts['client_flags'] = client_flags
- except:
+ except Exception:
pass
return [[], opts]
+ @util.memoized_property
+ def _mysqlconnector_version_info(self):
+ if self.dbapi and hasattr(self.dbapi, '__version__'):
+ m = re.match(r'(\d+)\.(\d+)(?:\.(\d+))?',
+ self.dbapi.__version__)
+ if m:
+ return tuple(
+ int(x)
+ for x in m.group(1, 2, 3)
+ if x is not None)
+
+ @util.memoized_property
+ def _mysqlconnector_double_percents(self):
+ return not util.py3k and self._mysqlconnector_version_info < (2, 0)
+
def _get_server_version_info(self, connection):
dbapi_con = connection.connection
version = dbapi_con.get_server_version()
diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py
index 81a9f1a95..6df38e57e 100644
--- a/lib/sqlalchemy/dialects/oracle/base.py
+++ b/lib/sqlalchemy/dialects/oracle/base.py
@@ -213,6 +213,21 @@ is reflected and the type is reported as ``DATE``, the time-supporting
examining the type of column for use in special Python translations or
for migrating schemas to other database backends.
+Oracle Table Options
+-------------------------
+
+The CREATE TABLE phrase supports the following options with Oracle
+in conjunction with the :class:`.Table` construct:
+
+
+* ``ON COMMIT``::
+
+ Table(
+ "some_table", metadata, ...,
+ prefixes=['GLOBAL TEMPORARY'], oracle_on_commit='PRESERVE ROWS')
+
+.. versionadded:: 1.0.0
+
"""
import re
@@ -784,11 +799,22 @@ class OracleDDLCompiler(compiler.DDLCompiler):
return super(OracleDDLCompiler, self).\
visit_create_index(create, include_schema=True)
+ def post_create_table(self, table):
+ table_opts = []
+ opts = table.dialect_options['oracle']
+
+ if opts['on_commit']:
+ on_commit_options = opts['on_commit'].replace("_", " ").upper()
+ table_opts.append('\n ON COMMIT %s' % on_commit_options)
+
+ return ''.join(table_opts)
+
class OracleIdentifierPreparer(compiler.IdentifierPreparer):
reserved_words = set([x.lower() for x in RESERVED_WORDS])
- illegal_initial_characters = set(range(0, 10)).union(["_", "$"])
+ illegal_initial_characters = set(
+ (str(dig) for dig in range(0, 10))).union(["_", "$"])
def _bindparam_requires_quotes(self, value):
"""Return True if the given identifier requires quoting."""
@@ -842,7 +868,10 @@ class OracleDialect(default.DefaultDialect):
reflection_options = ('oracle_resolve_synonyms', )
construct_arguments = [
- (sa_schema.Table, {"resolve_synonyms": False})
+ (sa_schema.Table, {
+ "resolve_synonyms": False,
+ "on_commit": None
+ })
]
def __init__(self,
@@ -1029,7 +1058,21 @@ class OracleDialect(default.DefaultDialect):
"WHERE nvl(tablespace_name, 'no tablespace') NOT IN "
"('SYSTEM', 'SYSAUX') "
"AND OWNER = :owner "
- "AND IOT_NAME IS NULL")
+ "AND IOT_NAME IS NULL "
+ "AND DURATION IS NULL")
+ cursor = connection.execute(s, owner=schema)
+ return [self.normalize_name(row[0]) for row in cursor]
+
+ @reflection.cache
+ def get_temp_table_names(self, connection, **kw):
+ schema = self.denormalize_name(self.default_schema_name)
+ s = sql.text(
+ "SELECT table_name FROM all_tables "
+ "WHERE nvl(tablespace_name, 'no tablespace') NOT IN "
+ "('SYSTEM', 'SYSAUX') "
+ "AND OWNER = :owner "
+ "AND IOT_NAME IS NULL "
+ "AND DURATION IS NOT NULL")
cursor = connection.execute(s, owner=schema)
return [self.normalize_name(row[0]) for row in cursor]
diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py
index 575d2a6dd..baa640eaa 100644
--- a/lib/sqlalchemy/dialects/postgresql/base.py
+++ b/lib/sqlalchemy/dialects/postgresql/base.py
@@ -401,6 +401,29 @@ The value passed to the keyword argument will be simply passed through to the
underlying CREATE INDEX command, so it *must* be a valid index type for your
version of PostgreSQL.
+
+.. _postgresql_index_reflection:
+
+Postgresql Index Reflection
+---------------------------
+
+The Postgresql database creates a UNIQUE INDEX implicitly whenever the
+UNIQUE CONSTRAINT construct is used. When inspecting a table using
+:class:`.Inspector`, the :meth:`.Inspector.get_indexes`
+and the :meth:`.Inspector.get_unique_constraints` will report on these
+two constructs distinctly; in the case of the index, the key
+``duplicates_constraint`` will be present in the index entry if it is
+detected as mirroring a constraint. When performing reflection using
+``Table(..., autoload=True)``, the UNIQUE INDEX is **not** returned
+in :attr:`.Table.indexes` when it is detected as mirroring a
+:class:`.UniqueConstraint` in the :attr:`.Table.constraints` collection.
+
+.. versionchanged:: 1.0.0 - :class:`.Table` reflection now includes
+ :class:`.UniqueConstraint` objects present in the :attr:`.Table.constraints`
+ collection; the Postgresql backend will no longer include a "mirrored"
+ :class:`.Index` construct in :attr:`.Table.indexes` if it is detected
+ as corresponding to a unique constraint.
+
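+For example, a rough sketch of inspecting such an index (the engine URL
+and table name are illustrative)::
+
+    from sqlalchemy import create_engine, inspect
+
+    insp = inspect(create_engine("postgresql://scott:tiger@localhost/test"))
+
+    # entries that mirror a UNIQUE CONSTRAINT carry the
+    # "duplicates_constraint" key
+    insp.get_indexes("some_table")
+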
Special Reflection Options
--------------------------
@@ -1679,6 +1702,19 @@ class PGInspector(reflection.Inspector):
schema = schema or self.default_schema_name
return self.dialect._load_enums(self.bind, schema)
+ def get_foreign_table_names(self, schema=None):
+ """Return a list of FOREIGN TABLE names.
+
+ Behavior is similar to that of :meth:`.Inspector.get_table_names`,
+ except that the list is limited to those tables that report a
+ ``relkind`` value of ``f``.
+
+ .. versionadded:: 1.0.0
+
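+ E.g., a rough usage sketch (the engine URL is illustrative)::
+
+     from sqlalchemy import create_engine, inspect
+
+     insp = inspect(create_engine("postgresql://scott:tiger@localhost/test"))
+     insp.get_foreign_table_names()
+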
+ """
+ schema = schema or self.default_schema_name
+ return self.dialect._get_foreign_table_names(self.bind, schema)
+
class CreateEnumType(schema._CreateDropBase):
__visit_name__ = "create_enum_type"
@@ -2024,7 +2060,7 @@ class PGDialect(default.DefaultDialect):
FROM pg_catalog.pg_class c
LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
WHERE (%s)
- AND c.relname = :table_name AND c.relkind in ('r','v')
+ AND c.relname = :table_name AND c.relkind in ('r', 'v', 'm', 'f')
""" % schema_where_clause
# Since we're binding to unicode, table_name and schema_name must be
# unicode.
@@ -2078,6 +2114,24 @@ class PGDialect(default.DefaultDialect):
return [row[0] for row in result]
@reflection.cache
+ def _get_foreign_table_names(self, connection, schema=None, **kw):
+ if schema is not None:
+ current_schema = schema
+ else:
+ current_schema = self.default_schema_name
+
+ result = connection.execute(
+ sql.text("SELECT relname FROM pg_class c "
+ "WHERE relkind = 'f' "
+ "AND '%s' = (select nspname from pg_namespace n "
+ "where n.oid = c.relnamespace) " %
+ current_schema,
+ typemap={'relname': sqltypes.Unicode}
+ )
+ )
+ return [row[0] for row in result]
+
+ @reflection.cache
def get_view_names(self, connection, schema=None, **kw):
if schema is not None:
current_schema = schema
@@ -2086,7 +2140,7 @@ class PGDialect(default.DefaultDialect):
s = """
SELECT relname
FROM pg_class c
- WHERE relkind = 'v'
+ WHERE relkind IN ('m', 'v')
AND '%(schema)s' = (select nspname from pg_namespace n
where n.oid = c.relnamespace)
""" % dict(schema=current_schema)
@@ -2439,16 +2493,21 @@ class PGDialect(default.DefaultDialect):
SELECT
i.relname as relname,
ix.indisunique, ix.indexprs, ix.indpred,
- a.attname, a.attnum, ix.indkey%s
+ a.attname, a.attnum, c.conrelid, ix.indkey%s
FROM
pg_class t
join pg_index ix on t.oid = ix.indrelid
- join pg_class i on i.oid=ix.indexrelid
+ join pg_class i on i.oid = ix.indexrelid
left outer join
pg_attribute a
- on t.oid=a.attrelid and %s
+ on t.oid = a.attrelid and %s
+ left outer join
+ pg_constraint c
+ on (ix.indrelid = c.conrelid and
+ ix.indexrelid = c.conindid and
+ c.contype in ('p', 'u', 'x'))
WHERE
- t.relkind = 'r'
+ t.relkind IN ('r', 'v', 'f', 'm')
and t.oid = :table_oid
and ix.indisprimary = 'f'
ORDER BY
@@ -2469,7 +2528,7 @@ class PGDialect(default.DefaultDialect):
sv_idx_name = None
for row in c.fetchall():
- idx_name, unique, expr, prd, col, col_num, idx_key = row
+ idx_name, unique, expr, prd, col, col_num, conrelid, idx_key = row
if expr:
if idx_name != sv_idx_name:
@@ -2486,18 +2545,27 @@ class PGDialect(default.DefaultDialect):
% idx_name)
sv_idx_name = idx_name
+ has_idx = idx_name in indexes
index = indexes[idx_name]
if col is not None:
index['cols'][col_num] = col
- index['key'] = [int(k.strip()) for k in idx_key.split()]
- index['unique'] = unique
-
- return [
- {'name': name,
- 'unique': idx['unique'],
- 'column_names': [idx['cols'][i] for i in idx['key']]}
- for name, idx in indexes.items()
- ]
+ if not has_idx:
+ index['key'] = [int(k.strip()) for k in idx_key.split()]
+ index['unique'] = unique
+ if conrelid is not None:
+ index['duplicates_constraint'] = idx_name
+
+ result = []
+ for name, idx in indexes.items():
+ entry = {
+ 'name': name,
+ 'unique': idx['unique'],
+ 'column_names': [idx['cols'][i] for i in idx['key']]
+ }
+ if 'duplicates_constraint' in idx:
+ entry['duplicates_constraint'] = idx['duplicates_constraint']
+ result.append(entry)
+ return result
@reflection.cache
def get_unique_constraints(self, connection, table_name,
diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2.py b/lib/sqlalchemy/dialects/postgresql/psycopg2.py
index e6450c97f..1a2a1ffe4 100644
--- a/lib/sqlalchemy/dialects/postgresql/psycopg2.py
+++ b/lib/sqlalchemy/dialects/postgresql/psycopg2.py
@@ -32,10 +32,25 @@ psycopg2-specific keyword arguments which are accepted by
way of enabling this mode on a per-execution basis.
* ``use_native_unicode``: Enable the usage of Psycopg2 "native unicode" mode
per connection. True by default.
+
+ .. seealso::
+
+ :ref:`psycopg2_disable_native_unicode`
+
* ``isolation_level``: This option, available for all PostgreSQL dialects,
includes the ``AUTOCOMMIT`` isolation level when using the psycopg2
- dialect. See :ref:`psycopg2_isolation_level`.
+ dialect.
+
+ .. seealso::
+
+ :ref:`psycopg2_isolation_level`
+
+* ``client_encoding``: sets the client encoding in a libpq-agnostic way,
+ using psycopg2's ``set_client_encoding()`` method.
+
+ .. seealso::
+
+ :ref:`psycopg2_unicode`
Unix Domain Connections
------------------------
@@ -75,8 +90,10 @@ The following DBAPI-specific options are respected when used with
If ``None`` or not set, the ``server_side_cursors`` option of the
:class:`.Engine` is used.
-Unicode
--------
+.. _psycopg2_unicode:
+
+Unicode with Psycopg2
+----------------------
By default, the psycopg2 driver uses the ``psycopg2.extensions.UNICODE``
extension, such that the DBAPI receives and returns all strings as Python
@@ -84,27 +101,51 @@ Unicode objects directly - SQLAlchemy passes these values through without
change. Psycopg2 here will encode/decode string values based on the
current "client encoding" setting; by default this is the value in
the ``postgresql.conf`` file, which often defaults to ``SQL_ASCII``.
-Typically, this can be changed to ``utf-8``, as a more useful default::
+Typically, this can be changed to ``utf8``, as a more useful default::
+
+ # postgresql.conf file
- #client_encoding = sql_ascii # actually, defaults to database
+ # client_encoding = sql_ascii # actually, defaults to database
# encoding
client_encoding = utf8
A second way to affect the client encoding is to set it within Psycopg2
-locally. SQLAlchemy will call psycopg2's ``set_client_encoding()``
-method (see:
-http://initd.org/psycopg/docs/connection.html#connection.set_client_encoding)
+locally. SQLAlchemy will call psycopg2's
+:meth:`psycopg2:connection.set_client_encoding` method
on all new connections based on the value passed to
:func:`.create_engine` using the ``client_encoding`` parameter::
+ # set_client_encoding() setting;
+ # works for *all* Postgresql versions
engine = create_engine("postgresql://user:pass@host/dbname",
client_encoding='utf8')
This overrides the encoding specified in the Postgresql client configuration.
+When using the parameter in this way, the psycopg2 driver emits
+``SET client_encoding TO 'utf8'`` on the connection explicitly, and works
+in all Postgresql versions.
+
+Note that the ``client_encoding`` setting as passed to :func:`.create_engine`
+is **not the same** as the more recently added ``client_encoding`` parameter
+now supported by libpq directly. This is enabled when ``client_encoding``
+is passed directly to ``psycopg2.connect()``, and from SQLAlchemy is passed
+using the :paramref:`.create_engine.connect_args` parameter::
+
+ # libpq direct parameter setting;
+ # only works for Postgresql **9.1 and above**
+ engine = create_engine("postgresql://user:pass@host/dbname",
+ connect_args={'client_encoding': 'utf8'})
+
+ # using the query string is equivalent
+ engine = create_engine("postgresql://user:pass@host/dbname?client_encoding=utf8")
+
+The above parameter was only added to libpq as of version 9.1 of Postgresql,
+so using the previous method is better for cross-version support.
+
+.. _psycopg2_disable_native_unicode:
-.. versionadded:: 0.7.3
- The psycopg2-specific ``client_encoding`` parameter to
- :func:`.create_engine`.
+Disabling Native Unicode
+^^^^^^^^^^^^^^^^^^^^^^^^
SQLAlchemy can also be instructed to skip the usage of the psycopg2
``UNICODE`` extension and to instead utilize its own unicode encode/decode
@@ -116,8 +157,7 @@ in and coerce from bytes on the way back,
using the value of the :func:`.create_engine` ``encoding`` parameter, which
defaults to ``utf-8``.
SQLAlchemy's own unicode encode/decode functionality is steadily becoming
-obsolete as more DBAPIs support unicode fully along with the approach of
-Python 3; in modern usage psycopg2 should be relied upon to handle unicode.
+obsolete as most DBAPIs now support unicode fully.
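+
+For example, a minimal sketch disabling the native unicode extension
+(the connection URL is illustrative)::
+
+    engine = create_engine("postgresql+psycopg2://user:pass@host/dbname",
+                           use_native_unicode=False)
+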
Transactions
------------
@@ -512,12 +552,14 @@ class PGDialect_psycopg2(PGDialect):
def is_disconnect(self, e, connection, cursor):
if isinstance(e, self.dbapi.Error):
# check the "closed" flag. this might not be
- # present on old psycopg2 versions
+ # present on old psycopg2 versions. Also,
+ # this flag doesn't actually help in a lot of disconnect
+ # situations, so don't rely on it.
if getattr(connection, 'closed', False):
return True
- # legacy checks based on strings. the "closed" check
- # above most likely obviates the need for any of these.
+ # checks based on strings. in the case that .closed
+ # didn't cut it, fall back onto these.
str_e = str(e).partition("\n")[0]
for msg in [
# these error messages from libpq: interfaces/libpq/fe-misc.c
@@ -534,8 +576,10 @@ class PGDialect_psycopg2(PGDialect):
# not sure where this path is originally from, it may
# be obsolete. It really says "losed", not "closed".
'losed the connection unexpectedly',
- # this can occur in newer SSL
- 'connection has been closed unexpectedly'
+ # these can occur in newer SSL
+ 'connection has been closed unexpectedly',
+ 'SSL SYSCALL error: Bad file descriptor',
+ 'SSL SYSCALL error: EOF detected',
]:
idx = str_e.find(msg)
if idx >= 0 and '"' not in str_e[:idx]:
diff --git a/lib/sqlalchemy/dialects/sqlite/__init__.py b/lib/sqlalchemy/dialects/sqlite/__init__.py
index 0eceaa537..a53d53e9d 100644
--- a/lib/sqlalchemy/dialects/sqlite/__init__.py
+++ b/lib/sqlalchemy/dialects/sqlite/__init__.py
@@ -5,7 +5,7 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-from sqlalchemy.dialects.sqlite import base, pysqlite
+from sqlalchemy.dialects.sqlite import base, pysqlite, pysqlcipher
# default dialect
base.dialect = pysqlite.dialect
diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py
index af793d275..335b35c94 100644
--- a/lib/sqlalchemy/dialects/sqlite/base.py
+++ b/lib/sqlalchemy/dialects/sqlite/base.py
@@ -713,10 +713,12 @@ class SQLiteExecutionContext(default.DefaultExecutionContext):
return self.execution_options.get("sqlite_raw_colnames", False)
def _translate_colname(self, colname):
- # adjust for dotted column names. SQLite in the case of UNION may
- # store col names as "tablename.colname" in cursor.description
+ # adjust for dotted column names. SQLite
+ # in the case of UNION may store col names as
+ # "tablename.colname", or if using an attached database,
+ # "database.tablename.colname", in cursor.description
if not self._preserve_raw_colnames and "." in colname:
- return colname.split(".")[1], colname
+ return colname.split(".")[-1], colname
else:
return colname, None
@@ -829,20 +831,26 @@ class SQLiteDialect(default.DefaultDialect):
if schema is not None:
qschema = self.identifier_preparer.quote_identifier(schema)
master = '%s.sqlite_master' % qschema
- s = ("SELECT name FROM %s "
- "WHERE type='table' ORDER BY name") % (master,)
- rs = connection.execute(s)
else:
- try:
- s = ("SELECT name FROM "
- " (SELECT * FROM sqlite_master UNION ALL "
- " SELECT * FROM sqlite_temp_master) "
- "WHERE type='table' ORDER BY name")
- rs = connection.execute(s)
- except exc.DBAPIError:
- s = ("SELECT name FROM sqlite_master "
- "WHERE type='table' ORDER BY name")
- rs = connection.execute(s)
+ master = "sqlite_master"
+ s = ("SELECT name FROM %s "
+ "WHERE type='table' ORDER BY name") % (master,)
+ rs = connection.execute(s)
+ return [row[0] for row in rs]
+
+ @reflection.cache
+ def get_temp_table_names(self, connection, **kw):
+ s = "SELECT name FROM sqlite_temp_master "\
+ "WHERE type='table' ORDER BY name "
+ rs = connection.execute(s)
+
+ return [row[0] for row in rs]
+
+ @reflection.cache
+ def get_temp_view_names(self, connection, **kw):
+ s = "SELECT name FROM sqlite_temp_master "\
+ "WHERE type='view' ORDER BY name "
+ rs = connection.execute(s)
return [row[0] for row in rs]
@@ -869,20 +877,11 @@ class SQLiteDialect(default.DefaultDialect):
if schema is not None:
qschema = self.identifier_preparer.quote_identifier(schema)
master = '%s.sqlite_master' % qschema
- s = ("SELECT name FROM %s "
- "WHERE type='view' ORDER BY name") % (master,)
- rs = connection.execute(s)
else:
- try:
- s = ("SELECT name FROM "
- " (SELECT * FROM sqlite_master UNION ALL "
- " SELECT * FROM sqlite_temp_master) "
- "WHERE type='view' ORDER BY name")
- rs = connection.execute(s)
- except exc.DBAPIError:
- s = ("SELECT name FROM sqlite_master "
- "WHERE type='view' ORDER BY name")
- rs = connection.execute(s)
+ master = "sqlite_master"
+ s = ("SELECT name FROM %s "
+ "WHERE type='view' ORDER BY name") % (master,)
+ rs = connection.execute(s)
return [row[0] for row in rs]
@@ -1097,16 +1096,24 @@ class SQLiteDialect(default.DefaultDialect):
@reflection.cache
def get_unique_constraints(self, connection, table_name,
schema=None, **kw):
- UNIQUE_SQL = """
- SELECT sql
- FROM
- sqlite_master
- WHERE
- type='table' AND
- name=:table_name
- """
- c = connection.execute(UNIQUE_SQL, table_name=table_name)
- table_data = c.fetchone()[0]
+ try:
+ s = ("SELECT sql FROM "
+ " (SELECT * FROM sqlite_master UNION ALL "
+ " SELECT * FROM sqlite_temp_master) "
+ "WHERE name = '%s' "
+ "AND type = 'table'") % table_name
+ rs = connection.execute(s)
+ except exc.DBAPIError:
+ s = ("SELECT sql FROM sqlite_master WHERE name = '%s' "
+ "AND type = 'table'") % table_name
+ rs = connection.execute(s)
+ row = rs.fetchone()
+ if row is None:
+ # sqlite won't return the schema for the sqlite_master or
+ # sqlite_temp_master tables from this query. These tables
+ # don't have any unique constraints anyway.
+ return []
+ table_data = row[0]
UNIQUE_PATTERN = 'CONSTRAINT (\w+) UNIQUE \(([^\)]+)\)'
return [
diff --git a/lib/sqlalchemy/dialects/sqlite/pysqlcipher.py b/lib/sqlalchemy/dialects/sqlite/pysqlcipher.py
new file mode 100644
index 000000000..3c55a1de7
--- /dev/null
+++ b/lib/sqlalchemy/dialects/sqlite/pysqlcipher.py
@@ -0,0 +1,116 @@
+# sqlite/pysqlcipher.py
+# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+
+"""
+.. dialect:: sqlite+pysqlcipher
+ :name: pysqlcipher
+ :dbapi: pysqlcipher
+ :connectstring: sqlite+pysqlcipher://:passphrase/file_path[?kdf_iter=<iter>]
+ :url: https://pypi.python.org/pypi/pysqlcipher
+
+ ``pysqlcipher`` is a fork of the standard ``pysqlite`` driver to make
+ use of the `SQLCipher <https://www.zetetic.net/sqlcipher>`_ backend.
+
+ .. versionadded:: 0.9.9
+
+Driver
+------
+
+The driver here is the `pysqlcipher <https://pypi.python.org/pypi/pysqlcipher>`_
+driver, which makes use of the SQLCipher engine. This system essentially
+introduces new PRAGMA commands to SQLite which allows the setting of a
+passphrase and other encryption parameters, allowing the database
+file to be encrypted.
+
+Connect Strings
+---------------
+
+The format of the connect string is in every way the same as that
+of the :mod:`~sqlalchemy.dialects.sqlite.pysqlite` driver, except that the
+"password" field is now accepted, which should contain a passphrase::
+
+ e = create_engine('sqlite+pysqlcipher://:testing@/foo.db')
+
+For an absolute file path, two leading slashes should be used for the
+database name::
+
+ e = create_engine('sqlite+pysqlcipher://:testing@//path/to/foo.db')
+
+A selection of additional encryption-related pragmas supported by SQLCipher
+as documented at https://www.zetetic.net/sqlcipher/sqlcipher-api/ can be passed
+in the query string, and will result in that PRAGMA being called for each
+new connection. Currently, ``cipher``, ``kdf_iter``,
+``cipher_page_size`` and ``cipher_use_hmac`` are supported::
+
+ e = create_engine('sqlite+pysqlcipher://:testing@/foo.db?cipher=aes-256-cfb&kdf_iter=64000')
+
+
+Pooling Behavior
+----------------
+
+The driver makes a change to the default pool behavior of pysqlite
+as described in :ref:`pysqlite_threading_pooling`. The pysqlcipher driver
+has been observed to be significantly slower on connection than the
+pysqlite driver, most likely due to the encryption overhead, so the
+dialect here defaults to using the :class:`.SingletonThreadPool`
+implementation,
+instead of the :class:`.NullPool` pool used by pysqlite. As always, the pool
+implementation is entirely configurable using the
+:paramref:`.create_engine.poolclass` parameter; the :class:`.StaticPool` may
+be more feasible for single-threaded use, or :class:`.NullPool` may be used
+to prevent unencrypted connections from being held open for long periods of
+time, at the expense of slower startup time for new connections.
+
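+For example, a sketch opting back into :class:`.NullPool` (the passphrase
+and file path are illustrative)::
+
+    from sqlalchemy.pool import NullPool
+
+    e = create_engine('sqlite+pysqlcipher://:testing@/foo.db',
+                      poolclass=NullPool)
+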
+
+"""
+from __future__ import absolute_import
+from .pysqlite import SQLiteDialect_pysqlite
+from ...engine import url as _url
+from ... import pool
+
+
+class SQLiteDialect_pysqlcipher(SQLiteDialect_pysqlite):
+ driver = 'pysqlcipher'
+
+ pragmas = ('kdf_iter', 'cipher', 'cipher_page_size', 'cipher_use_hmac')
+
+ @classmethod
+ def dbapi(cls):
+ from pysqlcipher import dbapi2 as sqlcipher
+ return sqlcipher
+
+ @classmethod
+ def get_pool_class(cls, url):
+ return pool.SingletonThreadPool
+
+ def connect(self, *cargs, **cparams):
+ passphrase = cparams.pop('passphrase', '')
+
+ pragmas = dict(
+ (key, cparams.pop(key, None)) for key in
+ self.pragmas
+ )
+
+ conn = super(SQLiteDialect_pysqlcipher, self).\
+ connect(*cargs, **cparams)
+ conn.execute('pragma key="%s"' % passphrase)
+ for prag, value in pragmas.items():
+ if value is not None:
+ conn.execute('pragma %s=%s' % (prag, value))
+
+ return conn
+
+ def create_connect_args(self, url):
+ super_url = _url.URL(
+ url.drivername, username=url.username,
+ host=url.host, database=url.database, query=url.query)
+ c_args, opts = super(SQLiteDialect_pysqlcipher, self).\
+ create_connect_args(super_url)
+ opts['passphrase'] = url.password
+ return c_args, opts
+
+dialect = SQLiteDialect_pysqlcipher
diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py
index d2cc8890f..dd82be1d1 100644
--- a/lib/sqlalchemy/engine/base.py
+++ b/lib/sqlalchemy/engine/base.py
@@ -45,7 +45,7 @@ class Connection(Connectable):
"""
def __init__(self, engine, connection=None, close_with_result=False,
- _branch=False, _execution_options=None,
+ _branch_from=None, _execution_options=None,
_dispatch=None,
_has_events=None):
"""Construct a new Connection.
@@ -57,48 +57,80 @@ class Connection(Connectable):
"""
self.engine = engine
self.dialect = engine.dialect
- self.__connection = connection or engine.raw_connection()
- self.__transaction = None
- self.should_close_with_result = close_with_result
- self.__savepoint_seq = 0
- self.__branch = _branch
- self.__invalid = False
- self.__can_reconnect = True
- if _dispatch:
+ self.__branch_from = _branch_from
+ self.__branch = _branch_from is not None
+
+ if _branch_from:
+ self.__connection = connection
+ self._execution_options = _execution_options
+ self._echo = _branch_from._echo
+ self.should_close_with_result = False
self.dispatch = _dispatch
- elif _has_events is None:
- # if _has_events is sent explicitly as False,
- # then don't join the dispatch of the engine; we don't
- # want to handle any of the engine's events in that case.
- self.dispatch = self.dispatch._join(engine.dispatch)
- self._has_events = _has_events or (
- _has_events is None and engine._has_events)
-
- self._echo = self.engine._should_log_info()
- if _execution_options:
- self._execution_options =\
- engine._execution_options.union(_execution_options)
+ self._has_events = _branch_from._has_events
else:
+ self.__connection = connection \
+ if connection is not None else engine.raw_connection()
+ self.__transaction = None
+ self.__savepoint_seq = 0
+ self.should_close_with_result = close_with_result
+ self.__invalid = False
+ self.__can_reconnect = True
+ self._echo = self.engine._should_log_info()
+
+ if _has_events is None:
+ # if _has_events is sent explicitly as False,
+ # then don't join the dispatch of the engine; we don't
+ # want to handle any of the engine's events in that case.
+ self.dispatch = self.dispatch._join(engine.dispatch)
+ self._has_events = _has_events or (
+ _has_events is None and engine._has_events)
+
+ assert not _execution_options
self._execution_options = engine._execution_options
if self._has_events or self.engine._has_events:
- self.dispatch.engine_connect(self, _branch)
+ self.dispatch.engine_connect(self, self.__branch)
def _branch(self):
"""Return a new Connection which references this Connection's
engine and connection; but does not have close_with_result enabled,
and also whose close() method does nothing.
- This is used to execute "sub" statements within a single execution,
- usually an INSERT statement.
+ The Core uses this very sparingly, only in the case of
+ custom SQL default functions that are to be INSERTed as the
+ primary key of a row where we need to get the value back, so we have
+ to invoke it distinctly - this is a very uncommon case.
+
+ Userland code accesses _branch() when the connect() or
+ contextual_connect() methods are called. The branched connection
+ acts as much as possible like the parent, except that it stays
+ connected when a close() event occurs.
+
"""
+ if self.__branch_from:
+ return self.__branch_from._branch()
+ else:
+ return self.engine._connection_cls(
+ self.engine,
+ self.__connection,
+ _branch_from=self,
+ _execution_options=self._execution_options,
+ _has_events=self._has_events,
+ _dispatch=self.dispatch)
+
+ @property
+ def _root(self):
+ """return the 'root' connection.
- return self.engine._connection_cls(
- self.engine,
- self.__connection,
- _branch=True,
- _has_events=self._has_events,
- _dispatch=self.dispatch)
+ Returns 'self' if this connection is not a branch, else
+ returns the root connection from which we ultimately branched.
+
+ """
+
+ if self.__branch_from:
+ return self.__branch_from
+ else:
+ return self
def _clone(self):
"""Create a shallow copy of this Connection.
@@ -224,7 +256,7 @@ class Connection(Connectable):
def invalidated(self):
"""Return True if this connection was invalidated."""
- return self.__invalid
+ return self._root.__invalid
@property
def connection(self):
@@ -236,6 +268,9 @@ class Connection(Connectable):
return self._revalidate_connection()
def _revalidate_connection(self):
+ if self.__branch_from:
+ return self.__branch_from._revalidate_connection()
+
if self.__can_reconnect and self.__invalid:
if self.__transaction is not None:
raise exc.InvalidRequestError(
@@ -343,16 +378,17 @@ class Connection(Connectable):
:ref:`pool_connection_invalidation`
"""
+
if self.invalidated:
return
if self.closed:
raise exc.ResourceClosedError("This Connection is closed")
- if self._connection_is_valid:
- self.__connection.invalidate(exception)
- del self.__connection
- self.__invalid = True
+ if self._root._connection_is_valid:
+ self._root.__connection.invalidate(exception)
+ del self._root.__connection
+ self._root.__invalid = True
def detach(self):
"""Detach the underlying DB-API connection from its connection pool.
@@ -415,6 +451,8 @@ class Connection(Connectable):
:class:`.Engine`.
"""
+ if self.__branch_from:
+ return self.__branch_from.begin()
if self.__transaction is None:
self.__transaction = RootTransaction(self)
@@ -436,6 +474,9 @@ class Connection(Connectable):
See also :meth:`.Connection.begin`,
:meth:`.Connection.begin_twophase`.
"""
+ if self.__branch_from:
+ return self.__branch_from.begin_nested()
+
if self.__transaction is None:
self.__transaction = RootTransaction(self)
else:
@@ -459,6 +500,9 @@ class Connection(Connectable):
"""
+ if self.__branch_from:
+ return self.__branch_from.begin_twophase(xid=xid)
+
if self.__transaction is not None:
raise exc.InvalidRequestError(
"Cannot start a two phase transaction when a transaction "
@@ -479,10 +523,11 @@ class Connection(Connectable):
def in_transaction(self):
"""Return True if a transaction is in progress."""
-
- return self.__transaction is not None
+ return self._root.__transaction is not None
def _begin_impl(self, transaction):
+ assert not self.__branch_from
+
if self._echo:
self.engine.logger.info("BEGIN (implicit)")
@@ -497,6 +542,8 @@ class Connection(Connectable):
self._handle_dbapi_exception(e, None, None, None, None)
def _rollback_impl(self):
+ assert not self.__branch_from
+
if self._has_events or self.engine._has_events:
self.dispatch.rollback(self)
@@ -516,6 +563,8 @@ class Connection(Connectable):
self.__transaction = None
def _commit_impl(self, autocommit=False):
+ assert not self.__branch_from
+
if self._has_events or self.engine._has_events:
self.dispatch.commit(self)
@@ -532,6 +581,8 @@ class Connection(Connectable):
self.__transaction = None
def _savepoint_impl(self, name=None):
+ assert not self.__branch_from
+
if self._has_events or self.engine._has_events:
self.dispatch.savepoint(self, name)
@@ -543,6 +594,8 @@ class Connection(Connectable):
return name
def _rollback_to_savepoint_impl(self, name, context):
+ assert not self.__branch_from
+
if self._has_events or self.engine._has_events:
self.dispatch.rollback_savepoint(self, name, context)
@@ -551,6 +604,8 @@ class Connection(Connectable):
self.__transaction = context
def _release_savepoint_impl(self, name, context):
+ assert not self.__branch_from
+
if self._has_events or self.engine._has_events:
self.dispatch.release_savepoint(self, name, context)
@@ -559,6 +614,8 @@ class Connection(Connectable):
self.__transaction = context
def _begin_twophase_impl(self, transaction):
+ assert not self.__branch_from
+
if self._echo:
self.engine.logger.info("BEGIN TWOPHASE (implicit)")
if self._has_events or self.engine._has_events:
@@ -571,6 +628,8 @@ class Connection(Connectable):
self.connection._reset_agent = transaction
def _prepare_twophase_impl(self, xid):
+ assert not self.__branch_from
+
if self._has_events or self.engine._has_events:
self.dispatch.prepare_twophase(self, xid)
@@ -579,6 +638,8 @@ class Connection(Connectable):
self.engine.dialect.do_prepare_twophase(self, xid)
def _rollback_twophase_impl(self, xid, is_prepared):
+ assert not self.__branch_from
+
if self._has_events or self.engine._has_events:
self.dispatch.rollback_twophase(self, xid, is_prepared)
@@ -595,6 +656,8 @@ class Connection(Connectable):
self.__transaction = None
def _commit_twophase_impl(self, xid, is_prepared):
+ assert not self.__branch_from
+
if self._has_events or self.engine._has_events:
self.dispatch.commit_twophase(self, xid, is_prepared)
@@ -610,8 +673,8 @@ class Connection(Connectable):
self.__transaction = None
def _autorollback(self):
- if not self.in_transaction():
- self._rollback_impl()
+ if not self._root.in_transaction():
+ self._root._rollback_impl()
def close(self):
"""Close this :class:`.Connection`.
@@ -632,13 +695,21 @@ class Connection(Connectable):
and will allow no further operations.
"""
+ if self.__branch_from:
+ try:
+ del self.__connection
+ except AttributeError:
+ pass
+ finally:
+ self.__can_reconnect = False
+ return
try:
conn = self.__connection
except AttributeError:
pass
else:
- if not self.__branch:
- conn.close()
+
+ conn.close()
if conn._reset_agent is self.__transaction:
conn._reset_agent = None
@@ -993,8 +1064,8 @@ class Connection(Connectable):
result.rowcount
result.close(_autoclose_connection=False)
- if self.__transaction is None and context.should_autocommit:
- self._commit_impl(autocommit=True)
+ if context.should_autocommit and self._root.__transaction is None:
+ self._root._commit_impl(autocommit=True)
if result.closed and self.should_close_with_result:
self.close()
@@ -1055,8 +1126,6 @@ class Connection(Connectable):
"""
try:
cursor.close()
- except (SystemExit, KeyboardInterrupt):
- raise
except Exception:
# log the error through the connection pool's logger.
self.engine.pool.logger.error(
diff --git a/lib/sqlalchemy/engine/interfaces.py b/lib/sqlalchemy/engine/interfaces.py
index 71df29cac..0ad2efae0 100644
--- a/lib/sqlalchemy/engine/interfaces.py
+++ b/lib/sqlalchemy/engine/interfaces.py
@@ -308,7 +308,15 @@ class Dialect(object):
def get_table_names(self, connection, schema=None, **kw):
"""Return a list of table names for `schema`."""
- raise NotImplementedError
+ raise NotImplementedError()
+
+ def get_temp_table_names(self, connection, schema=None, **kw):
+ """Return a list of temporary table names on the given connection,
+ if supported by the underlying backend.
+
+ """
+
+ raise NotImplementedError()
def get_view_names(self, connection, schema=None, **kw):
"""Return a list of all view names available in the database.
@@ -319,6 +327,14 @@ class Dialect(object):
raise NotImplementedError()
+ def get_temp_view_names(self, connection, schema=None, **kw):
+ """Return a list of temporary view names on the given connection,
+ if supported by the underlying backend.
+
+ """
+
+ raise NotImplementedError()
+
def get_view_definition(self, connection, view_name, schema=None, **kw):
"""Return view definition.
diff --git a/lib/sqlalchemy/engine/reflection.py b/lib/sqlalchemy/engine/reflection.py
index cf1f2d3dd..2a1def86a 100644
--- a/lib/sqlalchemy/engine/reflection.py
+++ b/lib/sqlalchemy/engine/reflection.py
@@ -201,6 +201,30 @@ class Inspector(object):
tnames = list(topological.sort(tuples, tnames))
return tnames
+ def get_temp_table_names(self):
+ """return a list of temporary table names for the current bind.
+
+ This method is unsupported by most dialects; currently
+ only the SQLite and Oracle dialects implement it.
+
+ .. versionadded:: 1.0.0
+
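+ E.g., a rough sketch against SQLite (the in-memory URL is illustrative)::
+
+     from sqlalchemy import create_engine, inspect
+
+     insp = inspect(create_engine("sqlite://"))
+     insp.get_temp_table_names()
+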
+ """
+ return self.dialect.get_temp_table_names(
+ self.bind, info_cache=self.info_cache)
+
+ def get_temp_view_names(self):
+ """return a list of temporary view names for the current bind.
+
+ This method is unsupported by most dialects; currently
+ only SQLite implements it.
+
+ .. versionadded:: 1.0.0
+
+ """
+ return self.dialect.get_temp_view_names(
+ self.bind, info_cache=self.info_cache)
+
def get_table_options(self, table_name, schema=None, **kw):
"""Return a dictionary of options specified when the table of the
given name was created.
@@ -465,55 +489,87 @@ class Inspector(object):
for col_d in self.get_columns(
table_name, schema, **table.dialect_kwargs):
found_table = True
- orig_name = col_d['name']
- table.dispatch.column_reflect(self, table, col_d)
+ self._reflect_column(
+ table, col_d, include_columns,
+ exclude_columns, cols_by_orig_name)
- name = col_d['name']
- if include_columns and name not in include_columns:
- continue
- if exclude_columns and name in exclude_columns:
- continue
+ if not found_table:
+ raise exc.NoSuchTableError(table.name)
- coltype = col_d['type']
+ self._reflect_pk(
+ table_name, schema, table, cols_by_orig_name, exclude_columns)
- col_kw = dict(
- (k, col_d[k])
- for k in ['nullable', 'autoincrement', 'quote', 'info', 'key']
- if k in col_d
- )
+ self._reflect_fk(
+ table_name, schema, table, cols_by_orig_name,
+ exclude_columns, reflection_options)
- colargs = []
- if col_d.get('default') is not None:
- # the "default" value is assumed to be a literal SQL
- # expression, so is wrapped in text() so that no quoting
- # occurs on re-issuance.
- colargs.append(
- sa_schema.DefaultClause(
- sql.text(col_d['default']), _reflected=True
- )
- )
+ self._reflect_indexes(
+ table_name, schema, table, cols_by_orig_name,
+ include_columns, exclude_columns, reflection_options)
- if 'sequence' in col_d:
- # TODO: mssql and sybase are using this.
- seq = col_d['sequence']
- sequence = sa_schema.Sequence(seq['name'], 1, 1)
- if 'start' in seq:
- sequence.start = seq['start']
- if 'increment' in seq:
- sequence.increment = seq['increment']
- colargs.append(sequence)
+ self._reflect_unique_constraints(
+ table_name, schema, table, cols_by_orig_name,
+ include_columns, exclude_columns, reflection_options)
- cols_by_orig_name[orig_name] = col = \
- sa_schema.Column(name, coltype, *colargs, **col_kw)
+ def _reflect_column(
+ self, table, col_d, include_columns,
+ exclude_columns, cols_by_orig_name):
- if col.key in table.primary_key:
- col.primary_key = True
- table.append_column(col)
+ orig_name = col_d['name']
- if not found_table:
- raise exc.NoSuchTableError(table.name)
+ table.dispatch.column_reflect(self, table, col_d)
+
+ # fetch name again as column_reflect is allowed to
+ # change it
+ name = col_d['name']
+ if (include_columns and name not in include_columns) \
+ or (exclude_columns and name in exclude_columns):
+ return
+
+ coltype = col_d['type']
+
+ col_kw = dict(
+ (k, col_d[k])
+ for k in ['nullable', 'autoincrement', 'quote', 'info', 'key']
+ if k in col_d
+ )
+
+ colargs = []
+ if col_d.get('default') is not None:
+ # the "default" value is assumed to be a literal SQL
+ # expression, so is wrapped in text() so that no quoting
+ # occurs on re-issuance.
+ colargs.append(
+ sa_schema.DefaultClause(
+ sql.text(col_d['default']), _reflected=True
+ )
+ )
+ if 'sequence' in col_d:
+ self._reflect_col_sequence(col_d, colargs)
+
+ cols_by_orig_name[orig_name] = col = \
+ sa_schema.Column(name, coltype, *colargs, **col_kw)
+
+ if col.key in table.primary_key:
+ col.primary_key = True
+ table.append_column(col)
+
+ def _reflect_col_sequence(self, col_d, colargs):
+ if 'sequence' in col_d:
+ # TODO: mssql and sybase are using this.
+ seq = col_d['sequence']
+ sequence = sa_schema.Sequence(seq['name'], 1, 1)
+ if 'start' in seq:
+ sequence.start = seq['start']
+ if 'increment' in seq:
+ sequence.increment = seq['increment']
+ colargs.append(sequence)
+
+ def _reflect_pk(
+ self, table_name, schema, table,
+ cols_by_orig_name, exclude_columns):
pk_cons = self.get_pk_constraint(
table_name, schema, **table.dialect_kwargs)
if pk_cons:
@@ -530,6 +586,9 @@ class Inspector(object):
# its column collection
table.primary_key._reload(pk_cols)
+ def _reflect_fk(
+ self, table_name, schema, table, cols_by_orig_name,
+ exclude_columns, reflection_options):
fkeys = self.get_foreign_keys(
table_name, schema, **table.dialect_kwargs)
for fkey_d in fkeys:
@@ -572,6 +631,10 @@ class Inspector(object):
sa_schema.ForeignKeyConstraint(constrained_columns, refspec,
conname, link_to_name=True,
**options))
+
+ def _reflect_indexes(
+ self, table_name, schema, table, cols_by_orig_name,
+ include_columns, exclude_columns, reflection_options):
# Indexes
indexes = self.get_indexes(table_name, schema)
for index_d in indexes:
@@ -579,12 +642,15 @@ class Inspector(object):
columns = index_d['column_names']
unique = index_d['unique']
flavor = index_d.get('type', 'index')
+ duplicates = index_d.get('duplicates_constraint')
if include_columns and \
not set(columns).issubset(include_columns):
util.warn(
"Omitting %s key for (%s), key covers omitted columns." %
(flavor, ', '.join(columns)))
continue
+ if duplicates:
+ continue
# look for columns by orig name in cols_by_orig_name,
# but support columns that are in-Python only as fallback
idx_cols = []
@@ -602,3 +668,43 @@ class Inspector(object):
idx_cols.append(idx_col)
sa_schema.Index(name, *idx_cols, **dict(unique=unique))
+
+ def _reflect_unique_constraints(
+ self, table_name, schema, table, cols_by_orig_name,
+ include_columns, exclude_columns, reflection_options):
+
+ # Unique Constraints
+ try:
+ constraints = self.get_unique_constraints(table_name, schema)
+ except NotImplementedError:
+ # optional dialect feature
+ return
+
+ for const_d in constraints:
+ conname = const_d['name']
+ columns = const_d['column_names']
+ duplicates = const_d.get('duplicates_index')
+ if include_columns and \
+ not set(columns).issubset(include_columns):
+ util.warn(
+ "Omitting unique constraint key for (%s), "
+ "key covers omitted columns." %
+ ', '.join(columns))
+ continue
+ if duplicates:
+ continue
+ # look for columns by orig name in cols_by_orig_name,
+ # but support columns that are in-Python only as fallback
+ constrained_cols = []
+ for c in columns:
+ try:
+ constrained_col = cols_by_orig_name[c] \
+ if c in cols_by_orig_name else table.c[c]
+ except KeyError:
+ util.warn(
+ "unique constraint key '%s' was not located in "
+ "columns for table '%s'" % (c, table_name))
+ else:
+ constrained_cols.append(constrained_col)
+ table.append_constraint(
+ sa_schema.UniqueConstraint(*constrained_cols, name=conname))
diff --git a/lib/sqlalchemy/engine/strategies.py b/lib/sqlalchemy/engine/strategies.py
index 38206be89..398ef8df6 100644
--- a/lib/sqlalchemy/engine/strategies.py
+++ b/lib/sqlalchemy/engine/strategies.py
@@ -162,6 +162,7 @@ class DefaultEngineStrategy(EngineStrategy):
def first_connect(dbapi_connection, connection_record):
c = base.Connection(engine, connection=dbapi_connection,
_has_events=False)
+ c._execution_options = util.immutabledict()
dialect.initialize(c)
event.listen(pool, 'first_connect', first_connect, once=True)
diff --git a/lib/sqlalchemy/event/attr.py b/lib/sqlalchemy/event/attr.py
index dba1063cf..be2a82208 100644
--- a/lib/sqlalchemy/event/attr.py
+++ b/lib/sqlalchemy/event/attr.py
@@ -319,14 +319,12 @@ class _ListenerCollection(RefCollection, _CompoundListener):
registry._stored_in_collection_multi(self, other, to_associate)
def insert(self, event_key, propagate):
- if event_key._listen_fn not in self.listeners:
- event_key.prepend_to_list(self, self.listeners)
+ if event_key.prepend_to_list(self, self.listeners):
if propagate:
self.propagate.add(event_key._listen_fn)
def append(self, event_key, propagate):
- if event_key._listen_fn not in self.listeners:
- event_key.append_to_list(self, self.listeners)
+ if event_key.append_to_list(self, self.listeners):
if propagate:
self.propagate.add(event_key._listen_fn)
diff --git a/lib/sqlalchemy/event/registry.py b/lib/sqlalchemy/event/registry.py
index ba2f671a3..5b422c401 100644
--- a/lib/sqlalchemy/event/registry.py
+++ b/lib/sqlalchemy/event/registry.py
@@ -71,13 +71,15 @@ def _stored_in_collection(event_key, owner):
listen_ref = weakref.ref(event_key._listen_fn)
if owner_ref in dispatch_reg:
- assert dispatch_reg[owner_ref] == listen_ref
- else:
- dispatch_reg[owner_ref] = listen_ref
+ return False
+
+ dispatch_reg[owner_ref] = listen_ref
listener_to_key = _collection_to_key[owner_ref]
listener_to_key[listen_ref] = key
+ return True
+
def _removed_from_collection(event_key, owner):
key = event_key._key
@@ -180,6 +182,17 @@ class _EventKey(object):
def listen(self, *args, **kw):
once = kw.pop("once", False)
+ named = kw.pop("named", False)
+
+ target, identifier, fn = \
+ self.dispatch_target, self.identifier, self._listen_fn
+
+ dispatch_descriptor = getattr(target.dispatch, identifier)
+
+ adjusted_fn = dispatch_descriptor._adjust_fn_spec(fn, named)
+
+ self = self.with_wrapper(adjusted_fn)
+
if once:
self.with_wrapper(
util.only_once(self._listen_fn)).listen(*args, **kw)
@@ -215,9 +228,6 @@ class _EventKey(object):
dispatch_descriptor = getattr(target.dispatch, identifier)
- fn = dispatch_descriptor._adjust_fn_spec(fn, named)
- self = self.with_wrapper(fn)
-
if insert:
dispatch_descriptor.\
for_modify(target.dispatch).insert(self, propagate)
@@ -229,18 +239,20 @@ class _EventKey(object):
def _listen_fn(self):
return self.fn_wrap or self.fn
- def append_value_to_list(self, owner, list_, value):
- _stored_in_collection(self, owner)
- list_.append(value)
-
def append_to_list(self, owner, list_):
- _stored_in_collection(self, owner)
- list_.append(self._listen_fn)
+ if _stored_in_collection(self, owner):
+ list_.append(self._listen_fn)
+ return True
+ else:
+ return False
def remove_from_list(self, owner, list_):
_removed_from_collection(self, owner)
list_.remove(self._listen_fn)
def prepend_to_list(self, owner, list_):
- _stored_in_collection(self, owner)
- list_.appendleft(self._listen_fn)
+ if _stored_in_collection(self, owner):
+ list_.appendleft(self._listen_fn)
+ return True
+ else:
+ return False
diff --git a/lib/sqlalchemy/events.py b/lib/sqlalchemy/events.py
index 1ecec51b6..b4f057b0a 100644
--- a/lib/sqlalchemy/events.py
+++ b/lib/sqlalchemy/events.py
@@ -338,7 +338,7 @@ class PoolEvents(event.Events):
"""
- def reset(self, dbapi_connnection, connection_record):
+ def reset(self, dbapi_connection, connection_record):
"""Called before the "reset" action occurs for a pooled connection.
This event represents
@@ -470,7 +470,8 @@ class ConnectionEvents(event.Events):
@classmethod
def _listen(cls, event_key, retval=False):
target, identifier, fn = \
- event_key.dispatch_target, event_key.identifier, event_key.fn
+ event_key.dispatch_target, event_key.identifier, \
+ event_key._listen_fn
target._has_events = True
diff --git a/lib/sqlalchemy/exc.py b/lib/sqlalchemy/exc.py
index a82bae33f..3271d09d4 100644
--- a/lib/sqlalchemy/exc.py
+++ b/lib/sqlalchemy/exc.py
@@ -238,14 +238,16 @@ class StatementError(SQLAlchemyError):
def __str__(self):
from sqlalchemy.sql import util
- params_repr = util._repr_params(self.params, 10)
+ details = [SQLAlchemyError.__str__(self)]
+ if self.statement:
+ details.append("[SQL: %r]" % self.statement)
+ if self.params:
+ params_repr = util._repr_params(self.params, 10)
+ details.append("[parameters: %r]" % params_repr)
return ' '.join([
"(%s)" % det for det in self.detail
- ] + [
- SQLAlchemyError.__str__(self),
- repr(self.statement), repr(params_repr)
- ])
+ ] + details)
def __unicode__(self):
return self.__str__()
@@ -280,17 +282,19 @@ class DBAPIError(StatementError):
connection_invalidated=False):
# Don't ever wrap these, just return them directly as if
# DBAPIError didn't exist.
- if isinstance(orig, (KeyboardInterrupt, SystemExit, DontWrapMixin)):
+ if (isinstance(orig, BaseException) and
+ not isinstance(orig, Exception)) or \
+ isinstance(orig, DontWrapMixin):
return orig
if orig is not None:
# not a DBAPI error, statement is present.
# raise a StatementError
if not isinstance(orig, dbapi_base_err) and statement:
- msg = traceback.format_exception_only(
- orig.__class__, orig)[-1].strip()
return StatementError(
- "%s (original cause: %s)" % (str(orig), msg),
+ "(%s.%s) %s" %
+ (orig.__class__.__module__, orig.__class__.__name__,
+ orig),
statement, params, orig
)
@@ -310,13 +314,12 @@ class DBAPIError(StatementError):
def __init__(self, statement, params, orig, connection_invalidated=False):
try:
text = str(orig)
- except (KeyboardInterrupt, SystemExit):
- raise
except Exception as e:
text = 'Error in str() of DB-API-generated exception: ' + str(e)
StatementError.__init__(
self,
- '(%s) %s' % (orig.__class__.__name__, text),
+ '(%s.%s) %s' % (
+ orig.__class__.__module__, orig.__class__.__name__, text, ),
statement,
params,
orig
diff --git a/lib/sqlalchemy/ext/automap.py b/lib/sqlalchemy/ext/automap.py
index 121285ab3..c11795d37 100644
--- a/lib/sqlalchemy/ext/automap.py
+++ b/lib/sqlalchemy/ext/automap.py
@@ -243,7 +243,26 @@ follows:
one-to-many backref will be created on the referred class referring
to this class.
-4. The names of the relationships are determined using the
+4. If any of the columns that are part of the :class:`.ForeignKeyConstraint`
+ are not nullable (e.g. ``nullable=False``), a
+ :paramref:`~.relationship.cascade` keyword argument
+ of ``all, delete-orphan`` will be added to the keyword arguments to
+ be passed to the relationship or backref. If the
+ :class:`.ForeignKeyConstraint` reports that
+ :paramref:`.ForeignKeyConstraint.ondelete`
+ is set to ``CASCADE`` for a not null or ``SET NULL`` for a nullable
+ set of columns, the option :paramref:`~.relationship.passive_deletes`
+ flag is set to ``True`` in the set of relationship keyword arguments.
+ Note that not all backends support reflection of ON DELETE.
+
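+ For example, given an illustrative, hypothetical schema such as::
+
+     Table('address', metadata,
+         Column('id', Integer, primary_key=True),
+         Column('user_id', ForeignKey('user.id', ondelete='CASCADE'),
+                nullable=False))
+
+ the one-to-many relationship generated on the ``user`` class would be
+ expected to receive both ``cascade="all, delete-orphan"`` and
+ ``passive_deletes=True``.
+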
+ .. versionadded:: 1.0.0 - automap will detect non-nullable foreign key
+ constraints when producing a one-to-many relationship and establish
+ a default cascade of ``all, delete-orphan`` if so; additionally,
+ if the constraint specifies :paramref:`.ForeignKeyConstraint.ondelete`
+ of ``CASCADE`` for non-nullable or ``SET NULL`` for nullable columns,
+ the ``passive_deletes=True`` option is also added.
+
+5. The names of the relationships are determined using the
:paramref:`.AutomapBase.prepare.name_for_scalar_relationship` and
:paramref:`.AutomapBase.prepare.name_for_collection_relationship`
callable functions. It is important to note that the default relationship
@@ -252,18 +271,18 @@ follows:
alternate class naming scheme, that's the name from which the relationship
name will be derived.
-5. The classes are inspected for an existing mapped property matching these
+6. The classes are inspected for an existing mapped property matching these
names. If one is detected on one side, but none on the other side,
:class:`.AutomapBase` attempts to create a relationship on the missing side,
then uses the :paramref:`.relationship.back_populates` parameter in order to
point the new relationship to the other side.
-6. In the usual case where no relationship is on either side,
+7. In the usual case where no relationship is on either side,
:meth:`.AutomapBase.prepare` produces a :func:`.relationship` on the
"many-to-one" side and matches it to the other using the
:paramref:`.relationship.backref` parameter.
-7. Production of the :func:`.relationship` and optionally the :func:`.backref`
+8. Production of the :func:`.relationship` and optionally the :func:`.backref`
is handed off to the :paramref:`.AutomapBase.prepare.generate_relationship`
function, which can be supplied by the end-user in order to augment
the arguments passed to :func:`.relationship` or :func:`.backref` or to
@@ -877,6 +896,19 @@ def _relationships_for_fks(automap_base, map_config, table_to_map_config,
constraint
)
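+ # if any column of the constraint is NOT NULL, default the
+ # one-to-many side to cascade="all, delete-orphan"; ON DELETE
+ # CASCADE / SET NULL is additionally mirrored as passive_deletes=True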
+ o2m_kws = {}
+ nullable = False not in set([fk.parent.nullable for fk in fks])
+ if not nullable:
+     o2m_kws['cascade'] = "all, delete-orphan"
+
+     if constraint.ondelete and \
+             constraint.ondelete.lower() == "cascade":
+         o2m_kws['passive_deletes'] = True
+ else:
+     if constraint.ondelete and \
+             constraint.ondelete.lower() == "set null":
+         o2m_kws['passive_deletes'] = True
+
create_backref = backref_name not in referred_cfg.properties
if relationship_name not in map_config.properties:
@@ -885,7 +917,8 @@ def _relationships_for_fks(automap_base, map_config, table_to_map_config,
automap_base,
interfaces.ONETOMANY, backref,
backref_name, referred_cls, local_cls,
- collection_class=collection_class)
+ collection_class=collection_class,
+ **o2m_kws)
else:
backref_obj = None
rel = generate_relationship(automap_base,
@@ -916,7 +949,8 @@ def _relationships_for_fks(automap_base, map_config, table_to_map_config,
fk.parent
for fk in constraint.elements],
back_populates=relationship_name,
- collection_class=collection_class)
+ collection_class=collection_class,
+ **o2m_kws)
if rel is not None:
referred_cfg.properties[backref_name] = rel
map_config.properties[
diff --git a/lib/sqlalchemy/ext/declarative/__init__.py b/lib/sqlalchemy/ext/declarative/__init__.py
index 3cbc85c0c..2b611252a 100644
--- a/lib/sqlalchemy/ext/declarative/__init__.py
+++ b/lib/sqlalchemy/ext/declarative/__init__.py
@@ -873,8 +873,7 @@ the method without the need to copy it.
Columns generated by :class:`~.declared_attr` can also be
referenced by ``__mapper_args__`` to a limited degree, currently
-by ``polymorphic_on`` and ``version_id_col``, by specifying the
-classdecorator itself into the dictionary - the declarative extension
+by ``polymorphic_on`` and ``version_id_col``; the declarative extension
will resolve them at class construction time::
class MyMixin:
@@ -889,7 +888,6 @@ will resolve them at class construction time::
id = Column(Integer, primary_key=True)
-
Mixing in Relationships
~~~~~~~~~~~~~~~~~~~~~~~
@@ -922,6 +920,7 @@ reference a common target class via many-to-one::
__tablename__ = 'target'
id = Column(Integer, primary_key=True)
+
Using Advanced Relationship Arguments (e.g. ``primaryjoin``, etc.)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -1004,6 +1003,24 @@ requirement so that no reliance on copying is needed::
class Something(SomethingMixin, Base):
__tablename__ = "something"
+The :func:`.column_property` or other construct may refer
+to other columns from the mixin. These are copied to the mapped class
+ahead of time, before the :class:`.declared_attr` is invoked::
+
+ class SomethingMixin(object):
+ x = Column(Integer)
+
+ y = Column(Integer)
+
+ @declared_attr
+ def x_plus_y(cls):
+ return column_property(cls.x + cls.y)
+
+
+.. versionchanged:: 1.0.0 mixin columns are copied to the final mapped class
+ so that :class:`.declared_attr` methods can access the actual column
+ that will be mapped.
+
Mixing in Association Proxy and Other Attributes
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -1087,19 +1104,20 @@ and ``TypeB`` classes.
Controlling table inheritance with mixins
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-The ``__tablename__`` attribute in conjunction with the hierarchy of
-classes involved in a declarative mixin scenario controls what type of
-table inheritance, if any,
-is configured by the declarative extension.
+The ``__tablename__`` attribute may be used to provide a function that
+will determine the name of the table used for each class in an inheritance
+hierarchy, as well as whether a class has its own distinct table.
-If the ``__tablename__`` is computed by a mixin, you may need to
-control which classes get the computed attribute in order to get the
-type of table inheritance you require.
+This is achieved using the :class:`.declared_attr` indicator in conjunction
+with a method named ``__tablename__()``. Declarative will always
+invoke :class:`.declared_attr` for the special names
+``__tablename__``, ``__mapper_args__`` and ``__table_args__``
+**for each mapped class in the hierarchy**. The function therefore
+needs to accept each class individually and to provide the
+correct answer for each.
-For example, if you had a mixin that computes ``__tablename__`` but
-where you wanted to use that mixin in a single table inheritance
-hierarchy, you can explicitly specify ``__tablename__`` as ``None`` to
-indicate that the class should not have a table mapped::
+For example, to create a mixin that gives every class a simple table
+name based on class name::
from sqlalchemy.ext.declarative import declared_attr
@@ -1118,15 +1136,10 @@ indicate that the class should not have a table mapped::
__mapper_args__ = {'polymorphic_identity': 'engineer'}
primary_language = Column(String(50))
-Alternatively, you can make the mixin intelligent enough to only
-return a ``__tablename__`` in the event that no table is already
-mapped in the inheritance hierarchy. To help with this, a
-:func:`~sqlalchemy.ext.declarative.has_inherited_table` helper
-function is provided that returns ``True`` if a parent class already
-has a mapped table.
-
-As an example, here's a mixin that will only allow single table
-inheritance::
+Alternatively, we can modify our ``__tablename__`` function to return
+``None`` for subclasses, using :func:`.has_inherited_table`. This has
+the effect that those subclasses are mapped with single table inheritance
+against the parent::
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.ext.declarative import has_inherited_table
@@ -1147,6 +1160,64 @@ inheritance::
primary_language = Column(String(50))
__mapper_args__ = {'polymorphic_identity': 'engineer'}
+.. _mixin_inheritance_columns:
+
+Mixing in Columns in Inheritance Scenarios
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+In contrast to how ``__tablename__`` and other special names are handled when
+used with :class:`.declared_attr`, when we mix in columns and properties (e.g.
+relationships, column properties, etc.), the function is
+invoked for the **base class only** in the hierarchy. Below, only the
+``Person`` class will receive a column
+called ``id``; the mapping will fail on ``Engineer``, which is not given
+a primary key::
+
+ class HasId(object):
+ @declared_attr
+ def id(cls):
+ return Column('id', Integer, primary_key=True)
+
+ class Person(HasId, Base):
+ __tablename__ = 'person'
+ discriminator = Column('type', String(50))
+ __mapper_args__ = {'polymorphic_on': discriminator}
+
+ class Engineer(Person):
+ __tablename__ = 'engineer'
+ primary_language = Column(String(50))
+ __mapper_args__ = {'polymorphic_identity': 'engineer'}
+
+It is usually the case in joined-table inheritance that we want distinctly
+named columns on each subclass. However, in this case, we may want to have
+an ``id`` column on every table, and have them refer to each other via
+foreign key. We can achieve this as a mixin by using the
+:attr:`.declared_attr.cascading` modifier, which indicates that the
+function should be invoked **for each class in the hierarchy**, just like
+it does for ``__tablename__``::
+
+ class HasId(object):
+ @declared_attr.cascading
+ def id(cls):
+ if has_inherited_table(cls):
+ return Column('id',
+ Integer,
+ ForeignKey('person.id'), primary_key=True)
+ else:
+ return Column('id', Integer, primary_key=True)
+
+ class Person(HasId, Base):
+ __tablename__ = 'person'
+ discriminator = Column('type', String(50))
+ __mapper_args__ = {'polymorphic_on': discriminator}
+
+ class Engineer(Person):
+ __tablename__ = 'engineer'
+ primary_language = Column(String(50))
+ __mapper_args__ = {'polymorphic_identity': 'engineer'}
+
+
+.. versionadded:: 1.0.0 added :attr:`.declared_attr.cascading`.
Combining Table/Mapper Arguments from Multiple Mixins
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
diff --git a/lib/sqlalchemy/ext/declarative/api.py b/lib/sqlalchemy/ext/declarative/api.py
index daf8bffb5..66fe05fd0 100644
--- a/lib/sqlalchemy/ext/declarative/api.py
+++ b/lib/sqlalchemy/ext/declarative/api.py
@@ -8,12 +8,13 @@
from ...schema import Table, MetaData
-from ...orm import synonym as _orm_synonym, mapper,\
+from ...orm import synonym as _orm_synonym, \
comparable_property,\
- interfaces, properties
+ interfaces, properties, attributes
from ...orm.util import polymorphic_union
from ...orm.base import _mapper_or_none
-from ...util import OrderedDict
+from ...util import OrderedDict, hybridmethod, hybridproperty
+from ... import util
from ... import exc
import weakref
@@ -21,7 +22,6 @@ from .base import _as_declarative, \
_declarative_constructor,\
_DeferredMapperConfig, _add_attribute
from .clsregistry import _class_resolver
-from . import clsregistry
def instrument_declarative(cls, registry, metadata):
@@ -157,12 +157,98 @@ class declared_attr(interfaces._MappedAttribute, property):
"""
- def __init__(self, fget, *arg, **kw):
- super(declared_attr, self).__init__(fget, *arg, **kw)
+ def __init__(self, fget, cascading=False):
+ super(declared_attr, self).__init__(fget)
self.__doc__ = fget.__doc__
+ self._cascading = cascading
def __get__(desc, self, cls):
- return desc.fget(cls)
+ # use the ClassManager for memoization of values. This is better than
+ # adding yet another attribute onto the class, or using weakrefs
+ # here which are slow and take up memory. It also allows us to
+ # warn for non-mapped use of declared_attr.
+
+ manager = attributes.manager_of_class(cls)
+ if manager is None:
+ util.warn(
+ "Unmanaged access of declarative attribute %s from "
+ "non-mapped class %s" %
+ (desc.fget.__name__, cls.__name__))
+ return desc.fget(cls)
+ try:
+ reg = manager.info['declared_attr_reg']
+ except KeyError:
+ raise exc.InvalidRequestError(
+ "@declared_attr called outside of the "
+ "declarative mapping process; is declarative_base() being "
+ "used correctly?")
+
+ if desc in reg:
+ return reg[desc]
+ else:
+ reg[desc] = obj = desc.fget(cls)
+ return obj
+
+ @hybridmethod
+ def _stateful(cls, **kw):
+ return _stateful_declared_attr(**kw)
+
+ @hybridproperty
+ def cascading(cls):
+ """Mark a :class:`.declared_attr` as cascading.
+
+ This is a special-use modifier which indicates that a column
+ or MapperProperty-based declared attribute should be configured
+ distinctly per mapped subclass, within a mapped-inheritance scenario.
+
+ Below, both ``MyClass`` and ``MySubClass`` will have a distinct
+ ``id`` Column object established::
+
+ class HasSomeAttribute(object):
+ @declared_attr.cascading
+ def some_id(cls):
+ if has_inherited_table(cls):
+ return Column(
+ ForeignKey('myclass.id'), primary_key=True)
+ else:
+ return Column(Integer, primary_key=True)
+
+ class MyClass(HasSomeAttribute, Base):
+ ""
+ # ...
+
+ class MySubClass(MyClass):
+ ""
+ # ...
+
+ The behavior of the above configuration is that ``MySubClass``
+ will refer to both its own ``id`` column and that of
+ ``MyClass`` underneath the attribute named ``some_id``.
+
+ .. seealso::
+
+ :ref:`declarative_inheritance`
+
+ :ref:`mixin_inheritance_columns`
+
+
+ """
+ return cls._stateful(cascading=True)
+
+
+class _stateful_declared_attr(declared_attr):
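+    """Intermediate form of :class:`.declared_attr` returned by modifiers
+    such as :attr:`.declared_attr.cascading`; it stores the given flags
+    and applies them to the final :class:`.declared_attr` when used as
+    a decorator.
+
+    """
+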
+ def __init__(self, **kw):
+ self.kw = kw
+
+ def _stateful(self, **kw):
+ new_kw = self.kw.copy()
+ new_kw.update(kw)
+ return _stateful_declared_attr(**new_kw)
+
+ def __call__(self, fn):
+ return declared_attr(fn, **self.kw)
def declarative_base(bind=None, metadata=None, mapper=None, cls=object,
@@ -349,9 +435,11 @@ class AbstractConcreteBase(ConcreteBase):
``__declare_last__()`` function, which is essentially
a hook for the :meth:`.after_configured` event.
- :class:`.AbstractConcreteBase` does not produce a mapped
- table for the class itself. Compare to :class:`.ConcreteBase`,
- which does.
+ :class:`.AbstractConcreteBase` does produce a mapped class
+ for the base class; however, it is not persisted to any table; it
+ is instead mapped directly to the "polymorphic" selectable
+ and is only used for selecting. Compare to :class:`.ConcreteBase`,
+ which does create a persisted table for the base class.
Example::
@@ -365,20 +453,72 @@ class AbstractConcreteBase(ConcreteBase):
employee_id = Column(Integer, primary_key=True)
name = Column(String(50))
manager_data = Column(String(40))
+
__mapper_args__ = {
- 'polymorphic_identity':'manager',
- 'concrete':True}
+ 'polymorphic_identity':'manager',
+ 'concrete':True}
+
+ The abstract base class is handled by declarative in a special way;
+ at class configuration time, it behaves like a declarative mixin
+ or an ``__abstract__`` base class. Once classes are configured
+ and mappings are produced, it then gets mapped itself, but
+ after all of its descendants. This is a unique system of mapping
+ not found elsewhere in SQLAlchemy.
+
+ Using this approach, we can specify columns and properties
+ that will be applied to mapped subclasses, in the way that
+ we normally do as in :ref:`declarative_mixins`::
+
+ class Company(Base):
+ __tablename__ = 'company'
+ id = Column(Integer, primary_key=True)
+
+ class Employee(AbstractConcreteBase, Base):
+ employee_id = Column(Integer, primary_key=True)
+
+ @declared_attr
+ def company_id(cls):
+ return Column(ForeignKey('company.id'))
+
+ @declared_attr
+ def company(cls):
+ return relationship("Company")
+
+ class Manager(Employee):
+ __tablename__ = 'manager'
+
+ name = Column(String(50))
+ manager_data = Column(String(40))
+
+ __mapper_args__ = {
+ 'polymorphic_identity':'manager',
+ 'concrete':True}
+
+ When we make use of our mappings, however, both ``Manager`` and
+ ``Employee`` will have an independently usable ``.company`` attribute::
+
+ session.query(Employee).filter(Employee.company.has(id=5))
+
+ .. versionchanged:: 1.0.0 - The mechanics of :class:`.AbstractConcreteBase`
+ have been reworked to support relationships established directly
+ on the abstract base, without any special configurational steps.
+
"""
- __abstract__ = True
+ __no_table__ = True
@classmethod
def __declare_first__(cls):
- if hasattr(cls, '__mapper__'):
+ cls._sa_decl_prepare_nocascade()
+
+ @classmethod
+ def _sa_decl_prepare_nocascade(cls):
+ if getattr(cls, '__mapper__', None):
return
- clsregistry.add_class(cls.__name__, cls)
+ to_map = _DeferredMapperConfig.config_for_cls(cls)
+
# can't rely on 'self_and_descendants' here
# since technically an immediate subclass
# might not be mapped, but a subclass
@@ -392,11 +532,22 @@ class AbstractConcreteBase(ConcreteBase):
if mn is not None:
mappers.append(mn)
pjoin = cls._create_polymorphic_union(mappers)
- cls.__mapper__ = m = mapper(cls, pjoin, polymorphic_on=pjoin.c.type)
+
+ to_map.local_table = pjoin
+
+ m_args = to_map.mapper_args_fn or dict
+
+ def mapper_args():
+ args = m_args()
+ args['polymorphic_on'] = pjoin.c.type
+ return args
+ to_map.mapper_args_fn = mapper_args
+
+ m = to_map.map()
for scls in cls.__subclasses__():
sm = _mapper_or_none(scls)
- if sm.concrete and cls in scls.__bases__:
+ if sm and sm.concrete and cls in scls.__bases__:
sm._set_concrete_base(m)
diff --git a/lib/sqlalchemy/ext/declarative/base.py b/lib/sqlalchemy/ext/declarative/base.py
index 94baeeb51..291608b6c 100644
--- a/lib/sqlalchemy/ext/declarative/base.py
+++ b/lib/sqlalchemy/ext/declarative/base.py
@@ -19,6 +19,9 @@ from ... import event
from . import clsregistry
import collections
import weakref
+from sqlalchemy.orm import instrumentation
+
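+# these are filled in lazily from .api at first use, avoiding a
+# circular import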
+declared_attr = declarative_props = None
def _declared_mapping_info(cls):
@@ -32,322 +35,407 @@ def _declared_mapping_info(cls):
return None
+def _get_immediate_cls_attr(cls, attrname):
+ """return an attribute of the class that is either present directly
+ on the class, e.g. not on a superclass, or is from a superclass but
+ this superclass is a mixin, that is, not a descendant of
+ the declarative base.
+
+ This is used to detect attributes that indicate something about
+ a mapped class independently from any mapped classes that it may
+ inherit from.
+
+ """
+ for base in cls.__mro__:
+     _is_declarative_inherits = hasattr(base, '_decl_class_registry')
+     if attrname in base.__dict__:
+         value = getattr(base, attrname)
+         if (base is cls or
+                 (base in cls.__bases__ and not _is_declarative_inherits)):
+             return value
+ else:
+     return None
+
+
def _as_declarative(cls, classname, dict_):
- from .api import declared_attr
+ global declared_attr, declarative_props
+ if declared_attr is None:
+ from .api import declared_attr
+ declarative_props = (declared_attr, util.classproperty)
- # dict_ will be a dictproxy, which we can't write to, and we need to!
- dict_ = dict(dict_)
+ if _get_immediate_cls_attr(cls, '__abstract__'):
+ return
- column_copies = {}
- potential_columns = {}
+ _MapperConfig.setup_mapping(cls, classname, dict_)
- mapper_args_fn = None
- table_args = inherited_table_args = None
- tablename = None
- declarative_props = (declared_attr, util.classproperty)
+class _MapperConfig(object):
- for base in cls.__mro__:
- _is_declarative_inherits = hasattr(base, '_decl_class_registry')
+ @classmethod
+ def setup_mapping(cls, cls_, classname, dict_):
+ defer_map = _get_immediate_cls_attr(
+ cls_, '_sa_decl_prepare_nocascade') or \
+ hasattr(cls_, '_sa_decl_prepare')
- if '__declare_last__' in base.__dict__:
- @event.listens_for(mapper, "after_configured")
- def go():
- cls.__declare_last__()
- if '__declare_first__' in base.__dict__:
- @event.listens_for(mapper, "before_configured")
- def go():
- cls.__declare_first__()
- if '__abstract__' in base.__dict__ and base.__abstract__:
- if (base is cls or
- (base in cls.__bases__ and not _is_declarative_inherits)):
- return
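+ # classes that use a _sa_decl_prepare hook (e.g. AbstractConcreteBase,
+ # DeferredReflection) are collected into _DeferredMapperConfig and
+ # mapped in a later prepare step, rather than at class construction time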
+ if defer_map:
+ cfg_cls = _DeferredMapperConfig
+ else:
+ cfg_cls = _MapperConfig
+ cfg_cls(cls_, classname, dict_)
- class_mapped = _declared_mapping_info(base) is not None
+ def __init__(self, cls_, classname, dict_):
- for name, obj in vars(base).items():
- if name == '__mapper_args__':
- if not mapper_args_fn and (
- not class_mapped or
- isinstance(obj, declarative_props)
- ):
- # don't even invoke __mapper_args__ until
- # after we've determined everything about the
- # mapped table.
- # make a copy of it so a class-level dictionary
- # is not overwritten when we update column-based
- # arguments.
- mapper_args_fn = lambda: dict(cls.__mapper_args__)
- elif name == '__tablename__':
- if not tablename and (
- not class_mapped or
- isinstance(obj, declarative_props)
- ):
- tablename = cls.__tablename__
- elif name == '__table_args__':
- if not table_args and (
- not class_mapped or
- isinstance(obj, declarative_props)
- ):
- table_args = cls.__table_args__
- if not isinstance(table_args, (tuple, dict, type(None))):
- raise exc.ArgumentError(
- "__table_args__ value must be a tuple, "
- "dict, or None")
- if base is not cls:
- inherited_table_args = True
- elif class_mapped:
- if isinstance(obj, declarative_props):
- util.warn("Regular (i.e. not __special__) "
- "attribute '%s.%s' uses @declared_attr, "
- "but owning class %s is mapped - "
- "not applying to subclass %s."
- % (base.__name__, name, base, cls))
- continue
- elif base is not cls:
- # we're a mixin.
- if isinstance(obj, Column):
- if getattr(cls, name) is not obj:
- # if column has been overridden
- # (like by the InstrumentedAttribute of the
- # superclass), skip
+ self.cls = cls_
+
+ # dict_ will be a dictproxy, which we can't write to, and we need to!
+ self.dict_ = dict(dict_)
+ self.classname = classname
+ self.mapped_table = None
+ self.properties = util.OrderedDict()
+ self.declared_columns = set()
+ self.column_copies = {}
+ self._setup_declared_events()
+
+ # register up front, so that @declared_attr can memoize
+ # function evaluations in .info
+ manager = instrumentation.register_class(self.cls)
+ manager.info['declared_attr_reg'] = {}
+
+ self._scan_attributes()
+
+ clsregistry.add_class(self.classname, self.cls)
+
+ self._extract_mappable_attributes()
+
+ self._extract_declared_columns()
+
+ self._setup_table()
+
+ self._setup_inheritance()
+
+ self._early_mapping()
+
+ def _early_mapping(self):
+ self.map()
+
+ def _setup_declared_events(self):
+ if _get_immediate_cls_attr(self.cls, '__declare_last__'):
+ @event.listens_for(mapper, "after_configured")
+ def after_configured():
+ self.cls.__declare_last__()
+
+ if _get_immediate_cls_attr(self.cls, '__declare_first__'):
+ @event.listens_for(mapper, "before_configured")
+ def before_configured():
+ self.cls.__declare_first__()
+
+ def _scan_attributes(self):
+ cls = self.cls
+ dict_ = self.dict_
+ column_copies = self.column_copies
+ mapper_args_fn = None
+ table_args = inherited_table_args = None
+ tablename = None
+
+ for base in cls.__mro__:
+ class_mapped = base is not cls and \
+ _declared_mapping_info(base) is not None and \
+ not _get_immediate_cls_attr(base, '_sa_decl_prepare_nocascade')
+
+ if not class_mapped and base is not cls:
+ self._produce_column_copies(base)
+
+ for name, obj in vars(base).items():
+ if name == '__mapper_args__':
+ if not mapper_args_fn and (
+ not class_mapped or
+ isinstance(obj, declarative_props)
+ ):
+ # don't even invoke __mapper_args__ until
+ # after we've determined everything about the
+ # mapped table.
+ # make a copy of it so a class-level dictionary
+ # is not overwritten when we update column-based
+ # arguments.
+ mapper_args_fn = lambda: dict(cls.__mapper_args__)
+ elif name == '__tablename__':
+ if not tablename and (
+ not class_mapped or
+ isinstance(obj, declarative_props)
+ ):
+ tablename = cls.__tablename__
+ elif name == '__table_args__':
+ if not table_args and (
+ not class_mapped or
+ isinstance(obj, declarative_props)
+ ):
+ table_args = cls.__table_args__
+ if not isinstance(
+ table_args, (tuple, dict, type(None))):
+ raise exc.ArgumentError(
+ "__table_args__ value must be a tuple, "
+ "dict, or None")
+ if base is not cls:
+ inherited_table_args = True
+ elif class_mapped:
+ if isinstance(obj, declarative_props):
+ util.warn("Regular (i.e. not __special__) "
+ "attribute '%s.%s' uses @declared_attr, "
+ "but owning class %s is mapped - "
+ "not applying to subclass %s."
+ % (base.__name__, name, base, cls))
+ continue
+ elif base is not cls:
+ # we're a mixin, abstract base, or something that is
+ # acting like that for now.
+ if isinstance(obj, Column):
+ # already copied columns to the mapped class.
continue
- if obj.foreign_keys:
+ elif isinstance(obj, MapperProperty):
raise exc.InvalidRequestError(
- "Columns with foreign keys to other columns "
- "must be declared as @declared_attr callables "
- "on declarative mixin classes. ")
- if name not in dict_ and not (
- '__table__' in dict_ and
- (obj.name or name) in dict_['__table__'].c
- ) and name not in potential_columns:
- potential_columns[name] = \
- column_copies[obj] = \
- obj.copy()
- column_copies[obj]._creation_order = \
- obj._creation_order
- elif isinstance(obj, MapperProperty):
+ "Mapper properties (i.e. deferred,"
+ "column_property(), relationship(), etc.) must "
+ "be declared as @declared_attr callables "
+ "on declarative mixin classes.")
+ elif isinstance(obj, declarative_props):
+ oldclassprop = isinstance(obj, util.classproperty)
+ if not oldclassprop and obj._cascading:
+ dict_[name] = column_copies[obj] = \
+ ret = obj.__get__(obj, cls)
+ else:
+ if oldclassprop:
+ util.warn_deprecated(
+ "Use of sqlalchemy.util.classproperty on "
+ "declarative classes is deprecated.")
+ dict_[name] = column_copies[obj] = \
+ ret = getattr(cls, name)
+ if isinstance(ret, (Column, MapperProperty)) and \
+ ret.doc is None:
+ ret.doc = obj.__doc__
+
+ if inherited_table_args and not tablename:
+ table_args = None
+
+ self.table_args = table_args
+ self.tablename = tablename
+ self.mapper_args_fn = mapper_args_fn
+
+ def _produce_column_copies(self, base):
+ cls = self.cls
+ dict_ = self.dict_
+ column_copies = self.column_copies
+ # copy mixin columns to the mapped class
+ for name, obj in vars(base).items():
+ if isinstance(obj, Column):
+ if getattr(cls, name) is not obj:
+ # if column has been overridden
+ # (like by the InstrumentedAttribute of the
+ # superclass), skip
+ continue
+ elif obj.foreign_keys:
raise exc.InvalidRequestError(
- "Mapper properties (i.e. deferred,"
- "column_property(), relationship(), etc.) must "
- "be declared as @declared_attr callables "
- "on declarative mixin classes.")
- elif isinstance(obj, declarative_props):
- dict_[name] = ret = \
- column_copies[obj] = getattr(cls, name)
- if isinstance(ret, (Column, MapperProperty)) and \
- ret.doc is None:
- ret.doc = obj.__doc__
-
- # apply inherited columns as we should
- for k, v in potential_columns.items():
- dict_[k] = v
-
- if inherited_table_args and not tablename:
- table_args = None
-
- clsregistry.add_class(classname, cls)
- our_stuff = util.OrderedDict()
-
- for k in list(dict_):
-
- # TODO: improve this ? all dunders ?
- if k in ('__table__', '__tablename__', '__mapper_args__'):
- continue
-
- value = dict_[k]
- if isinstance(value, declarative_props):
- value = getattr(cls, k)
-
- elif isinstance(value, QueryableAttribute) and \
- value.class_ is not cls and \
- value.key != k:
- # detect a QueryableAttribute that's already mapped being
- # assigned elsewhere in userland, turn into a synonym()
- value = synonym(value.key)
- setattr(cls, k, value)
-
- if (isinstance(value, tuple) and len(value) == 1 and
- isinstance(value[0], (Column, MapperProperty))):
- util.warn("Ignoring declarative-like tuple value of attribute "
- "%s: possibly a copy-and-paste error with a comma "
- "left at the end of the line?" % k)
- continue
- if not isinstance(value, (Column, MapperProperty)):
- if not k.startswith('__'):
- dict_.pop(k)
- setattr(cls, k, value)
- continue
- if k == 'metadata':
- raise exc.InvalidRequestError(
- "Attribute name 'metadata' is reserved "
- "for the MetaData instance when using a "
- "declarative base class."
- )
- prop = clsregistry._deferred_relationship(cls, value)
- our_stuff[k] = prop
-
- # set up attributes in the order they were created
- our_stuff.sort(key=lambda key: our_stuff[key]._creation_order)
-
- # extract columns from the class dict
- declared_columns = set()
- name_to_prop_key = collections.defaultdict(set)
- for key, c in list(our_stuff.items()):
- if isinstance(c, (ColumnProperty, CompositeProperty)):
- for col in c.columns:
- if isinstance(col, Column) and \
- col.table is None:
- _undefer_column_name(key, col)
- if not isinstance(c, CompositeProperty):
- name_to_prop_key[col.name].add(key)
- declared_columns.add(col)
- elif isinstance(c, Column):
- _undefer_column_name(key, c)
- name_to_prop_key[c.name].add(key)
- declared_columns.add(c)
- # if the column is the same name as the key,
- # remove it from the explicit properties dict.
- # the normal rules for assigning column-based properties
- # will take over, including precedence of columns
- # in multi-column ColumnProperties.
- if key == c.key:
- del our_stuff[key]
-
- for name, keys in name_to_prop_key.items():
- if len(keys) > 1:
- util.warn(
- "On class %r, Column object %r named directly multiple times, "
- "only one will be used: %s" %
- (classname, name, (", ".join(sorted(keys))))
- )
+ "Columns with foreign keys to other columns "
+ "must be declared as @declared_attr callables "
+ "on declarative mixin classes. ")
+ elif name not in dict_ and not (
+ '__table__' in dict_ and
+ (obj.name or name) in dict_['__table__'].c
+ ):
+ column_copies[obj] = copy_ = obj.copy()
+ copy_._creation_order = obj._creation_order
+ setattr(cls, name, copy_)
+ dict_[name] = copy_
- declared_columns = sorted(
- declared_columns, key=lambda c: c._creation_order)
- table = None
+ def _extract_mappable_attributes(self):
+ cls = self.cls
+ dict_ = self.dict_
- if hasattr(cls, '__table_cls__'):
- table_cls = util.unbound_method_to_callable(cls.__table_cls__)
- else:
- table_cls = Table
-
- if '__table__' not in dict_:
- if tablename is not None:
-
- args, table_kw = (), {}
- if table_args:
- if isinstance(table_args, dict):
- table_kw = table_args
- elif isinstance(table_args, tuple):
- if isinstance(table_args[-1], dict):
- args, table_kw = table_args[0:-1], table_args[-1]
- else:
- args = table_args
-
- autoload = dict_.get('__autoload__')
- if autoload:
- table_kw['autoload'] = True
-
- cls.__table__ = table = table_cls(
- tablename, cls.metadata,
- *(tuple(declared_columns) + tuple(args)),
- **table_kw)
- else:
- table = cls.__table__
- if declared_columns:
- for c in declared_columns:
- if not table.c.contains_column(c):
- raise exc.ArgumentError(
- "Can't add additional column %r when "
- "specifying __table__" % c.key
- )
+ our_stuff = self.properties
- if hasattr(cls, '__mapper_cls__'):
- mapper_cls = util.unbound_method_to_callable(cls.__mapper_cls__)
- else:
- mapper_cls = mapper
+ for k in list(dict_):
- for c in cls.__bases__:
- if _declared_mapping_info(c) is not None:
- inherits = c
- break
- else:
- inherits = None
+ if k in ('__table__', '__tablename__', '__mapper_args__'):
+ continue
- if table is None and inherits is None:
- raise exc.InvalidRequestError(
- "Class %r does not have a __table__ or __tablename__ "
- "specified and does not inherit from an existing "
- "table-mapped class." % cls
- )
- elif inherits:
- inherited_mapper = _declared_mapping_info(inherits)
- inherited_table = inherited_mapper.local_table
- inherited_mapped_table = inherited_mapper.mapped_table
-
- if table is None:
- # single table inheritance.
- # ensure no table args
- if table_args:
- raise exc.ArgumentError(
- "Can't place __table_args__ on an inherited class "
- "with no table."
+ value = dict_[k]
+ if isinstance(value, declarative_props):
+ value = getattr(cls, k)
+
+ elif isinstance(value, QueryableAttribute) and \
+ value.class_ is not cls and \
+ value.key != k:
+ # detect a QueryableAttribute that's already mapped being
+ # assigned elsewhere in userland, turn into a synonym()
+ value = synonym(value.key)
+ setattr(cls, k, value)
+
+ if (isinstance(value, tuple) and len(value) == 1 and
+ isinstance(value[0], (Column, MapperProperty))):
+ util.warn("Ignoring declarative-like tuple value of attribute "
+ "%s: possibly a copy-and-paste error with a comma "
+ "left at the end of the line?" % k)
+ continue
+ elif not isinstance(value, (Column, MapperProperty)):
+ # using @declared_attr for some object that
+ # isn't Column/MapperProperty; remove from the dict_
+ # and place the evaluated value onto the class.
+ if not k.startswith('__'):
+ dict_.pop(k)
+ setattr(cls, k, value)
+ continue
+ # we expect to see the name 'metadata' in some valid cases;
+ # however at this point we see it's assigned to something trying
+ # to be mapped, so raise for that.
+ elif k == 'metadata':
+ raise exc.InvalidRequestError(
+ "Attribute name 'metadata' is reserved "
+ "for the MetaData instance when using a "
+ "declarative base class."
+ )
+ prop = clsregistry._deferred_relationship(cls, value)
+ our_stuff[k] = prop
+
+ def _extract_declared_columns(self):
+ our_stuff = self.properties
+
+ # set up attributes in the order they were created
+ our_stuff.sort(key=lambda key: our_stuff[key]._creation_order)
+
+ # extract columns from the class dict
+ declared_columns = self.declared_columns
+ name_to_prop_key = collections.defaultdict(set)
+ for key, c in list(our_stuff.items()):
+ if isinstance(c, (ColumnProperty, CompositeProperty)):
+ for col in c.columns:
+ if isinstance(col, Column) and \
+ col.table is None:
+ _undefer_column_name(key, col)
+ if not isinstance(c, CompositeProperty):
+ name_to_prop_key[col.name].add(key)
+ declared_columns.add(col)
+ elif isinstance(c, Column):
+ _undefer_column_name(key, c)
+ name_to_prop_key[c.name].add(key)
+ declared_columns.add(c)
+ # if the column is the same name as the key,
+ # remove it from the explicit properties dict.
+ # the normal rules for assigning column-based properties
+ # will take over, including precedence of columns
+ # in multi-column ColumnProperties.
+ if key == c.key:
+ del our_stuff[key]
+
+ for name, keys in name_to_prop_key.items():
+ if len(keys) > 1:
+ util.warn(
+ "On class %r, Column object %r named "
+ "directly multiple times, "
+ "only one will be used: %s" %
+ (self.classname, name, (", ".join(sorted(keys))))
)
- # add any columns declared here to the inherited table.
- for c in declared_columns:
- if c.primary_key:
- raise exc.ArgumentError(
- "Can't place primary key columns on an inherited "
- "class with no table."
- )
- if c.name in inherited_table.c:
- if inherited_table.c[c.name] is c:
- continue
- raise exc.ArgumentError(
- "Column '%s' on class %s conflicts with "
- "existing column '%s'" %
- (c, cls, inherited_table.c[c.name])
- )
- inherited_table.append_column(c)
- if inherited_mapped_table is not None and \
- inherited_mapped_table is not inherited_table:
- inherited_mapped_table._refresh_for_new_column(c)
-
- defer_map = hasattr(cls, '_sa_decl_prepare')
- if defer_map:
- cfg_cls = _DeferredMapperConfig
- else:
- cfg_cls = _MapperConfig
- mt = cfg_cls(mapper_cls,
- cls, table,
- inherits,
- declared_columns,
- column_copies,
- our_stuff,
- mapper_args_fn)
- if not defer_map:
- mt.map()
+ def _setup_table(self):
+ cls = self.cls
+ tablename = self.tablename
+ table_args = self.table_args
+ dict_ = self.dict_
+ declared_columns = self.declared_columns
-class _MapperConfig(object):
+ declared_columns = self.declared_columns = sorted(
+ declared_columns, key=lambda c: c._creation_order)
+ table = None
- mapped_table = None
-
- def __init__(self, mapper_cls,
- cls,
- table,
- inherits,
- declared_columns,
- column_copies,
- properties, mapper_args_fn):
- self.mapper_cls = mapper_cls
- self.cls = cls
+ if hasattr(cls, '__table_cls__'):
+ table_cls = util.unbound_method_to_callable(cls.__table_cls__)
+ else:
+ table_cls = Table
+
+ if '__table__' not in dict_:
+ if tablename is not None:
+
+ args, table_kw = (), {}
+ if table_args:
+ if isinstance(table_args, dict):
+ table_kw = table_args
+ elif isinstance(table_args, tuple):
+ if isinstance(table_args[-1], dict):
+ args, table_kw = table_args[0:-1], table_args[-1]
+ else:
+ args = table_args
+
+ autoload = dict_.get('__autoload__')
+ if autoload:
+ table_kw['autoload'] = True
+
+ cls.__table__ = table = table_cls(
+ tablename, cls.metadata,
+ *(tuple(declared_columns) + tuple(args)),
+ **table_kw)
+ else:
+ table = cls.__table__
+ if declared_columns:
+ for c in declared_columns:
+ if not table.c.contains_column(c):
+ raise exc.ArgumentError(
+ "Can't add additional column %r when "
+ "specifying __table__" % c.key
+ )
self.local_table = table
- self.inherits = inherits
- self.properties = properties
- self.mapper_args_fn = mapper_args_fn
- self.declared_columns = declared_columns
- self.column_copies = column_copies
+
+ def _setup_inheritance(self):
+ table = self.local_table
+ cls = self.cls
+ table_args = self.table_args
+ declared_columns = self.declared_columns
+ for c in cls.__bases__:
+ if _declared_mapping_info(c) is not None and \
+ not _get_immediate_cls_attr(
+ c, '_sa_decl_prepare_nocascade'):
+ self.inherits = c
+ break
+ else:
+ self.inherits = None
+
+ if table is None and self.inherits is None and \
+ not _get_immediate_cls_attr(cls, '__no_table__'):
+
+ raise exc.InvalidRequestError(
+ "Class %r does not have a __table__ or __tablename__ "
+ "specified and does not inherit from an existing "
+ "table-mapped class." % cls
+ )
+ elif self.inherits:
+ inherited_mapper = _declared_mapping_info(self.inherits)
+ inherited_table = inherited_mapper.local_table
+ inherited_mapped_table = inherited_mapper.mapped_table
+
+ if table is None:
+ # single table inheritance.
+ # ensure no table args
+ if table_args:
+ raise exc.ArgumentError(
+ "Can't place __table_args__ on an inherited class "
+ "with no table."
+ )
+ # add any columns declared here to the inherited table.
+ for c in declared_columns:
+ if c.primary_key:
+ raise exc.ArgumentError(
+ "Can't place primary key columns on an inherited "
+ "class with no table."
+ )
+ if c.name in inherited_table.c:
+ if inherited_table.c[c.name] is c:
+ continue
+ raise exc.ArgumentError(
+ "Column '%s' on class %s conflicts with "
+ "existing column '%s'" %
+ (c, cls, inherited_table.c[c.name])
+ )
+ inherited_table.append_column(c)
+ if inherited_mapped_table is not None and \
+ inherited_mapped_table is not inherited_table:
+ inherited_mapped_table._refresh_for_new_column(c)
def _prepare_mapper_arguments(self):
properties = self.properties
@@ -401,20 +489,31 @@ class _MapperConfig(object):
properties[k] = [col] + p.columns
result_mapper_args = mapper_args.copy()
result_mapper_args['properties'] = properties
- return result_mapper_args
+ self.mapper_args = result_mapper_args
def map(self):
- mapper_args = self._prepare_mapper_arguments()
- self.cls.__mapper__ = self.mapper_cls(
+ self._prepare_mapper_arguments()
+ if hasattr(self.cls, '__mapper_cls__'):
+ mapper_cls = util.unbound_method_to_callable(
+ self.cls.__mapper_cls__)
+ else:
+ mapper_cls = mapper
+
+ self.cls.__mapper__ = mp_ = mapper_cls(
self.cls,
self.local_table,
- **mapper_args
+ **self.mapper_args
)
+ del mp_.class_manager.info['declared_attr_reg']
+ return mp_
class _DeferredMapperConfig(_MapperConfig):
_configs = util.OrderedDict()
+ def _early_mapping(self):
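+     # deferred configurations are not mapped at class construction
+     # time; map() is invoked later as part of a prepare step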
+ pass
+
@property
def cls(self):
return self._cls()
@@ -466,7 +565,7 @@ class _DeferredMapperConfig(_MapperConfig):
def map(self):
self._configs.pop(self._cls, None)
- super(_DeferredMapperConfig, self).map()
+ return super(_DeferredMapperConfig, self).map()
def _add_attribute(cls, key, value):
diff --git a/lib/sqlalchemy/ext/declarative/clsregistry.py b/lib/sqlalchemy/ext/declarative/clsregistry.py
index 4595b857a..3ef63a5ae 100644
--- a/lib/sqlalchemy/ext/declarative/clsregistry.py
+++ b/lib/sqlalchemy/ext/declarative/clsregistry.py
@@ -103,7 +103,12 @@ class _MultipleClassMarker(object):
self.on_remove()
def add_item(self, item):
- modules = set([cls().__module__ for cls in self.contents])
+ # protect against class registration race condition against
+ # asynchronous garbage collection calling _remove_item,
+ # [ticket:3208]
+ modules = set([
+ cls.__module__ for cls in
+ [ref() for ref in self.contents] if cls is not None])
if item.__module__ in modules:
util.warn(
"This declarative base already contains a class with the "
diff --git a/lib/sqlalchemy/orm/collections.py b/lib/sqlalchemy/orm/collections.py
index c2754d58f..356a8a3b9 100644
--- a/lib/sqlalchemy/orm/collections.py
+++ b/lib/sqlalchemy/orm/collections.py
@@ -861,11 +861,24 @@ def _instrument_class(cls):
"Can not instrument a built-in type. Use a "
"subclass, even a trivial one.")
+ roles, methods = _locate_roles_and_methods(cls)
+
+ _setup_canned_roles(cls, roles, methods)
+
+ _assert_required_roles(cls, roles, methods)
+
+ _set_collection_attributes(cls, roles, methods)
+
+
+def _locate_roles_and_methods(cls):
+ """search for _sa_instrument_role-decorated methods in
+ method resolution order, assign to roles.
+
+ """
+
roles = {}
methods = {}
- # search for _sa_instrument_role-decorated methods in
- # method resolution order, assign to roles
for supercls in cls.__mro__:
for name, method in vars(supercls).items():
if not util.callable(method):
@@ -890,14 +903,19 @@ def _instrument_class(cls):
assert op in ('fire_append_event', 'fire_remove_event')
after = op
if before:
- methods[name] = before[0], before[1], after
+ methods[name] = before + (after, )
elif after:
methods[name] = None, None, after
+ return roles, methods
+
- # see if this class has "canned" roles based on a known
- # collection type (dict, set, list). Apply those roles
- # as needed to the "roles" dictionary, and also
- # prepare "decorator" methods
+def _setup_canned_roles(cls, roles, methods):
+ """see if this class has "canned" roles based on a known
+ collection type (dict, set, list). Apply those roles
+ as needed to the "roles" dictionary, and also
+ prepare "decorator" methods
+
+ """
collection_type = util.duck_type_collection(cls)
if collection_type in __interfaces:
canned_roles, decorators = __interfaces[collection_type]
@@ -911,8 +929,12 @@ def _instrument_class(cls):
not hasattr(fn, '_sa_instrumented')):
setattr(cls, method, decorator(fn))
- # ensure all roles are present, and apply implicit instrumentation if
- # needed
+
+def _assert_required_roles(cls, roles, methods):
+ """ensure all roles are present, and apply implicit instrumentation if
+ needed
+
+ """
if 'appender' not in roles or not hasattr(cls, roles['appender']):
raise sa_exc.ArgumentError(
"Type %s must elect an appender method to be "
@@ -934,8 +956,12 @@ def _instrument_class(cls):
"Type %s must elect an iterator method to be "
"a collection class" % cls.__name__)
- # apply ad-hoc instrumentation from decorators, class-level defaults
- # and implicit role declarations
+
+def _set_collection_attributes(cls, roles, methods):
+ """apply ad-hoc instrumentation from decorators, class-level defaults
+ and implicit role declarations
+
+ """
for method_name, (before, argument, after) in methods.items():
setattr(cls, method_name,
_instrument_membership_mutator(getattr(cls, method_name),
diff --git a/lib/sqlalchemy/orm/events.py b/lib/sqlalchemy/orm/events.py
index c50a7b062..9ea0dd834 100644
--- a/lib/sqlalchemy/orm/events.py
+++ b/lib/sqlalchemy/orm/events.py
@@ -61,7 +61,8 @@ class InstrumentationEvents(event.Events):
@classmethod
def _listen(cls, event_key, propagate=True, **kw):
target, identifier, fn = \
- event_key.dispatch_target, event_key.identifier, event_key.fn
+ event_key.dispatch_target, event_key.identifier, \
+ event_key._listen_fn
def listen(target_cls, *arg):
listen_cls = target()
@@ -192,7 +193,8 @@ class InstanceEvents(event.Events):
@classmethod
def _listen(cls, event_key, raw=False, propagate=False, **kw):
target, identifier, fn = \
- event_key.dispatch_target, event_key.identifier, event_key.fn
+ event_key.dispatch_target, event_key.identifier, \
+ event_key._listen_fn
if not raw:
def wrap(state, *arg, **kw):
@@ -498,7 +500,8 @@ class MapperEvents(event.Events):
def _listen(
cls, event_key, raw=False, retval=False, propagate=False, **kw):
target, identifier, fn = \
- event_key.dispatch_target, event_key.identifier, event_key.fn
+ event_key.dispatch_target, event_key.identifier, \
+ event_key._listen_fn
if identifier in ("before_configured", "after_configured") and \
target is not mapperlib.Mapper:
@@ -1493,7 +1496,8 @@ class AttributeEvents(event.Events):
propagate=False):
target, identifier, fn = \
- event_key.dispatch_target, event_key.identifier, event_key.fn
+ event_key.dispatch_target, event_key.identifier, \
+ event_key._listen_fn
if active_history:
target.dispatch._active_history = True
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index 984f05256..082dae054 100644
--- a/lib/sqlalchemy/orm/mapper.py
+++ b/lib/sqlalchemy/orm/mapper.py
@@ -426,6 +426,12 @@ class Mapper(InspectionAttr):
thus persisting the value to the ``discriminator`` column
in the database.
+ .. warning::
+
+ Currently, **only one discriminator column may be set**, typically
+ on the base-most class in the hierarchy. "Cascading" polymorphic
+ columns are not yet supported.
+
.. seealso::
:ref:`inheritance_toplevel`
@@ -1080,6 +1086,9 @@ class Mapper(InspectionAttr):
auto-session attachment logic.
"""
+
+ # when using declarative as of 1.0, the register_class has
+ # already happened from within declarative.
manager = attributes.manager_of_class(self.class_)
if self.non_primary:
@@ -1102,18 +1111,14 @@ class Mapper(InspectionAttr):
"create a non primary Mapper. clear_mappers() will "
"remove *all* current mappers from all classes." %
self.class_)
- # else:
- # a ClassManager may already exist as
- # ClassManager.instrument_attribute() creates
- # new managers for each subclass if they don't yet exist.
+
+ if manager is None:
+ manager = instrumentation.register_class(self.class_)
_mapper_registry[self] = True
self.dispatch.instrument_class(self, self.class_)
- if manager is None:
- manager = instrumentation.register_class(self.class_)
-
self.class_manager = manager
manager.mapper = self
@@ -2657,7 +2662,7 @@ def configure_mappers():
mapper._expire_memoizations()
mapper.dispatch.mapper_configured(
mapper, mapper.class_)
- except:
+ except Exception:
exc = sys.exc_info()[1]
if not hasattr(exc, '_configure_failed'):
mapper._configure_failed = exc
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py
index 1288c910f..c4a9402fb 100644
--- a/lib/sqlalchemy/orm/persistence.py
+++ b/lib/sqlalchemy/orm/persistence.py
@@ -18,7 +18,7 @@ import operator
from itertools import groupby, chain
from .. import sql, util, exc as sa_exc, schema
from . import attributes, sync, exc as orm_exc, evaluator
-from .base import state_str, _attr_as_key
+from .base import state_str, _attr_as_key, _entity_descriptor
from ..sql import expression
from . import loading
@@ -560,9 +560,9 @@ def _collect_delete_commands(base_mapper, uowtransaction, table,
state, state_dict, col)
if value is None:
raise orm_exc.FlushError(
- "Can't delete from table "
+ "Can't delete from table %s "
"using NULL for primary "
- "key value")
+ "key value on column %s" % (table, col))
if update_version_id is not None and \
mapper.version_id_col in mapper._cols_by_table[table]:
@@ -1112,6 +1112,7 @@ class BulkUpdate(BulkUD):
super(BulkUpdate, self).__init__(query)
self.query._no_select_modifiers("update")
self.values = values
+ self.mapper = self.query._mapper_zero_or_none()
@classmethod
def factory(cls, query, synchronize_session, values):
@@ -1121,9 +1122,40 @@ class BulkUpdate(BulkUD):
False: BulkUpdate
}, synchronize_session, query, values)
+ def _resolve_string_to_expr(self, key):
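+     # a plain-string key in the values dictionary, e.g. {"age": ...},
+     # is resolved against the mapped entity to a column expression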
+ if self.mapper and isinstance(key, util.string_types):
+ attr = _entity_descriptor(self.mapper, key)
+ return attr.__clause_element__()
+ else:
+ return key
+
+ def _resolve_key_to_attrname(self, key):
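+     # normalize a values dictionary key - a string, an instrumented
+     # attribute, or a column expression - to the mapped attribute name,
+     # for use by the "evaluate" synchronization strategy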
+ if self.mapper and isinstance(key, util.string_types):
+ attr = _entity_descriptor(self.mapper, key)
+ return attr.property.key
+ elif isinstance(key, attributes.InstrumentedAttribute):
+ return key.key
+ elif hasattr(key, '__clause_element__'):
+ key = key.__clause_element__()
+
+ if self.mapper and isinstance(key, expression.ColumnElement):
+ try:
+ attr = self.mapper._columntoproperty[key]
+ except orm_exc.UnmappedColumnError:
+ return None
+ else:
+ return attr.key
+ else:
+ raise sa_exc.InvalidRequestError(
+ "Invalid expression type: %r" % key)
+
def _do_exec(self):
+ values = dict(
+ (self._resolve_string_to_expr(k), v)
+ for k, v in self.values.items()
+ )
update_stmt = sql.update(self.primary_table,
- self.context.whereclause, self.values)
+ self.context.whereclause, values)
self.result = self.query.session.execute(
update_stmt, params=self.query._params)
@@ -1169,9 +1201,10 @@ class BulkUpdateEvaluate(BulkEvaluate, BulkUpdate):
def _additional_evaluators(self, evaluator_compiler):
self.value_evaluators = {}
for key, value in self.values.items():
- key = _attr_as_key(key)
- self.value_evaluators[key] = evaluator_compiler.process(
- expression._literal_as_binds(value))
+ key = self._resolve_key_to_attrname(key)
+ if key is not None:
+ self.value_evaluators[key] = evaluator_compiler.process(
+ expression._literal_as_binds(value))
def _do_post_synchronize(self):
session = self.query.session
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py
index 60948293b..f07060825 100644
--- a/lib/sqlalchemy/orm/query.py
+++ b/lib/sqlalchemy/orm/query.py
@@ -1145,7 +1145,8 @@ class Query(object):
@_generative()
def with_hint(self, selectable, text, dialect_name='*'):
- """Add an indexing hint for the given entity or selectable to
+ """Add an indexing or other executional context
+ hint for the given entity or selectable to
this :class:`.Query`.
Functionality is passed straight through to
@@ -1153,11 +1154,35 @@ class Query(object):
with the addition that ``selectable`` can be a
:class:`.Table`, :class:`.Alias`, or ORM entity / mapped class
/etc.
+
+ .. seealso::
+
+ :meth:`.Query.with_statement_hint`
+
"""
- selectable = inspect(selectable).selectable
+ if selectable is not None:
+ selectable = inspect(selectable).selectable
self._with_hints += ((selectable, text, dialect_name),)
+ def with_statement_hint(self, text, dialect_name='*'):
+ """add a statement hint to this :class:`.Select`.
+
+ This method is similar to :meth:`.Select.with_hint` except that
+ it does not require an individual table, and instead applies to the
+ statement as a whole.
+
+ This feature calls down into :meth:`.Select.with_statement_hint`.
+
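+ E.g. (the hint text shown here is illustrative only)::
+
+     session.query(MyClass).\
+         with_statement_hint("some hint").all()
+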
+ .. versionadded:: 1.0.0
+
+ .. seealso::
+
+ :meth:`.Query.with_hint`
+
+ """
+ return self.with_hint(None, text, dialect_name)
+
@_generative()
def execution_options(self, **kwargs):
""" Set non-SQL options which take effect during execution.
@@ -1810,6 +1835,11 @@ class Query(object):
left_entity = prop = None
+ if isinstance(onclause, interfaces.PropComparator):
+ of_type = getattr(onclause, '_of_type', None)
+ else:
+ of_type = None
+
if isinstance(onclause, util.string_types):
left_entity = self._joinpoint_zero()
@@ -1836,8 +1866,6 @@ class Query(object):
if isinstance(onclause, interfaces.PropComparator):
if right_entity is None:
- right_entity = onclause.property.mapper
- of_type = getattr(onclause, '_of_type', None)
if of_type:
right_entity = of_type
else:
@@ -1919,11 +1947,9 @@ class Query(object):
from_obj, r_info.selectable):
overlap = True
break
- elif sql_util.selectables_overlap(l_info.selectable,
- r_info.selectable):
- overlap = True
- if overlap and l_info.selectable is r_info.selectable:
+ if (overlap or not create_aliases) and \
+ l_info.selectable is r_info.selectable:
raise sa_exc.InvalidRequestError(
"Can't join table/selectable '%s' to itself" %
l_info.selectable)
@@ -2591,6 +2617,19 @@ class Query(object):
SELECT 1 FROM users WHERE users.name = :name_1
) AS anon_1
+ The EXISTS construct is usually used in the WHERE clause::
+
+ session.query(User.id).filter(q.exists()).scalar()
+
+ Note that some databases such as SQL Server don't allow an
+ EXISTS expression to be present in the columns clause of a
+ SELECT. To select a simple boolean value based on the EXISTS
+ as the WHERE criterion, use :func:`.literal`::
+
+ from sqlalchemy import literal
+
+ session.query(literal(True)).filter(q.exists()).scalar()
+
.. versionadded:: 0.8.1
"""
@@ -2718,9 +2757,25 @@ class Query(object):
Updates rows matched by this query in the database.
- :param values: a dictionary with attributes names as keys and literal
+ E.g.::
+
+ sess.query(User).filter(User.age == 25).\
+ update({User.age: User.age - 10}, synchronize_session='fetch')
+
+
+ sess.query(User).filter(User.age == 25).\
+ update({"age": User.age - 10}, synchronize_session='evaluate')
+
+
+ :param values: a dictionary with attribute names, or alternatively
+ mapped attributes or SQL expressions, as keys, and literal
values or sql expressions as values.
+ .. versionchanged:: 1.0.0 - string names in the values dictionary
+ are now resolved against the mapped entity; previously, these
+ strings were passed as literal column names with no mapper-level
+ translation.
+
:param synchronize_session: chooses the strategy to update the
attributes on objects in the session. Valid values are:
@@ -2758,7 +2813,7 @@ class Query(object):
which normally occurs upon :meth:`.Session.commit` or can be forced
by using :meth:`.Session.expire_all`.
- * As of 0.8, this method will support multiple table updates, as
+ * The method supports multiple table updates, as
detailed in :ref:`multi_table_updates`, and this behavior does
extend to support updates of joined-inheritance and other multiple
table mappings. However, the **join condition of an inheritance
@@ -2789,12 +2844,6 @@ class Query(object):
"""
- # TODO: value keys need to be mapped to corresponding sql cols and
- # instr.attr.s to string keys
- # TODO: updates of manytoone relationships need to be converted to
- # fk assignments
- # TODO: cascades need handling.
-
update_op = persistence.BulkUpdate.factory(
self, synchronize_session, values)
update_op.exec_()
diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py
index 56a33742d..86f1b3f82 100644
--- a/lib/sqlalchemy/orm/relationships.py
+++ b/lib/sqlalchemy/orm/relationships.py
@@ -16,6 +16,7 @@ and `secondaryjoin` aspects of :func:`.relationship`.
from __future__ import absolute_import
from .. import sql, util, exc as sa_exc, schema, log
+import weakref
from .util import CascadeOptions, _orm_annotate, _orm_deannotate
from . import dependency
from . import attributes
@@ -1532,6 +1533,7 @@ class RelationshipProperty(StrategizedProperty):
self._check_cascade_settings(self._cascade)
self._post_init()
self._generate_backref()
+ self._join_condition._warn_for_conflicting_sync_targets()
super(RelationshipProperty, self).do_init()
self._lazy_strategy = self._get_strategy((("lazy", "select"),))
@@ -2519,6 +2521,60 @@ class JoinCondition(object):
self.secondary_synchronize_pairs = \
self._deannotate_pairs(secondary_sync_pairs)
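+ # mapping of target column -> weak-keyed {relationship: source column}
+ # dictionary, recording which relationships write to a given column;
+ # used by _warn_for_conflicting_sync_targets below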
+ _track_overlapping_sync_targets = weakref.WeakKeyDictionary()
+
+ def _warn_for_conflicting_sync_targets(self):
+ if not self.support_sync:
+ return
+
+ # we would like to detect if we are synchronizing any column
+ # pairs in conflict with another relationship that wishes to sync
+ # an entirely different column to the same target. This is a
+ # very rare edge case so we will try to minimize the memory/overhead
+ # impact of this check
+ for from_, to_ in [
+ (from_, to_) for (from_, to_) in self.synchronize_pairs
+ ] + [
+ (from_, to_) for (from_, to_) in self.secondary_synchronize_pairs
+ ]:
+ # save ourselves a ton of memory and overhead by only
+ # considering columns that are subject to a overlapping
+ # FK constraints at the core level. This condition can arise
+ # if multiple relationships overlap foreign() directly, but
+ # we're going to assume it's typically a ForeignKeyConstraint-
+ # level configuration that benefits from this warning.
+ if len(to_.foreign_keys) < 2:
+ continue
+
+ if to_ not in self._track_overlapping_sync_targets:
+ self._track_overlapping_sync_targets[to_] = \
+ weakref.WeakKeyDictionary({self.prop: from_})
+ else:
+ other_props = []
+ prop_to_from = self._track_overlapping_sync_targets[to_]
+ for pr, fr_ in prop_to_from.items():
+ if pr.mapper in mapperlib._mapper_registry and \
+ fr_ is not from_ and \
+ pr not in self.prop._reverse_property:
+ other_props.append((pr, fr_))
+
+ if other_props:
+ util.warn(
+ "relationship '%s' will copy column %s to column %s, "
+ "which conflicts with relationship(s): %s. "
+ "Consider applying "
+ "viewonly=True to read-only relationships, or provide "
+ "a primaryjoin condition marking writable columns "
+ "with the foreign() annotation." % (
+ self.prop,
+ from_, to_,
+ ", ".join(
+ "'%s' (copies %s to %s)" % (pr, fr_, to_)
+ for (pr, fr_) in other_props)
+ )
+ )
+ self._track_overlapping_sync_targets[to_][self.prop] = from_
+
@util.memoized_property
def remote_columns(self):
return self._gather_join_annotations("remote")
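
A hedged sketch of a configuration that would trigger the new warning; the
classes are illustrative. Two relationships each copy a different source
column into the same doubly-constrained foreign key column::

    from sqlalchemy import Column, ForeignKey, Integer
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import configure_mappers, relationship

    Base = declarative_base()

    class A(Base):
        __tablename__ = 'a'
        id = Column(Integer, primary_key=True)

    class B(Base):
        __tablename__ = 'b'
        id = Column(Integer, primary_key=True)

    class C(Base):
        __tablename__ = 'c'
        id = Column(Integer, primary_key=True)
        # one column constrained by two ForeignKey objects; C.a and C.b
        # would each sync a different source column into it
        a_or_b_id = Column(Integer, ForeignKey('a.id'), ForeignKey('b.id'))
        a = relationship("A")
        b = relationship("B")

    # emits the "will copy column ... which conflicts with
    # relationship(s)" warning; viewonly=True on one side silences it
    configure_mappers()
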
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py
index 1611688b0..ef911824c 100644
--- a/lib/sqlalchemy/orm/session.py
+++ b/lib/sqlalchemy/orm/session.py
@@ -294,7 +294,7 @@ class SessionTransaction(object):
for s in self.session.identity_map.all_states():
s._expire(s.dict, self.session.identity_map._modified)
for s in self._deleted:
- s.session_id = None
+ s._detach()
self._deleted.clear()
elif self.nested:
self._parent._new.update(self._new)
@@ -644,14 +644,8 @@ class Session(_SessionClassMethods):
SessionExtension._adapt_listener(self, ext)
if binds is not None:
- for mapperortable, bind in binds.items():
- insp = inspect(mapperortable)
- if insp.is_selectable:
- self.bind_table(mapperortable, bind)
- elif insp.is_mapper:
- self.bind_mapper(mapperortable, bind)
- else:
- assert False
+ for key, bind in binds.items():
+ self._add_bind(key, bind)
if not self.autocommit:
self.begin()
@@ -1029,40 +1023,47 @@ class Session(_SessionClassMethods):
# TODO: + crystallize + document resolution order
# vis. bind_mapper/bind_table
- def bind_mapper(self, mapper, bind):
- """Bind operations for a mapper to a Connectable.
-
- mapper
- A mapper instance or mapped class
+ def _add_bind(self, key, bind):
+ try:
+ insp = inspect(key)
+ except sa_exc.NoInspectionAvailable:
+ if not isinstance(key, type):
+ raise exc.ArgumentError(
+ "Not acceptable bind target: %s" %
+ key)
+ else:
+ self.__binds[key] = bind
+ else:
+ if insp.is_selectable:
+ self.__binds[insp] = bind
+ elif insp.is_mapper:
+ self.__binds[insp.class_] = bind
+ for selectable in insp._all_tables:
+ self.__binds[selectable] = bind
+ else:
+ raise exc.ArgumentError(
+ "Not acceptable bind target: %s" %
+ key)
- bind
- Any Connectable: a :class:`.Engine` or :class:`.Connection`.
+ def bind_mapper(self, mapper, bind):
+ """Associate a :class:`.Mapper` with a "bind", e.g. a :class:`.Engine`
+ or :class:`.Connection`.
- All subsequent operations involving this mapper will use the given
- `bind`.
+ The given mapper is added to a lookup used by the
+ :meth:`.Session.get_bind` method.
"""
- if isinstance(mapper, type):
- mapper = class_mapper(mapper)
-
- self.__binds[mapper.base_mapper] = bind
- for t in mapper._all_tables:
- self.__binds[t] = bind
+ self._add_bind(mapper, bind)
def bind_table(self, table, bind):
- """Bind operations on a Table to a Connectable.
+ """Associate a :class:`.Table` with a "bind", e.g. a :class:`.Engine`
+ or :class:`.Connection`.
- table
- A :class:`.Table` instance
-
- bind
- Any Connectable: a :class:`.Engine` or :class:`.Connection`.
-
- All subsequent operations involving this :class:`.Table` will use the
- given `bind`.
+        The given table is added to a lookup used by the
+ :meth:`.Session.get_bind` method.
"""
- self.__binds[table] = bind
+ self._add_bind(table, bind)
def get_bind(self, mapper=None, clause=None):
"""Return a "bind" to which this :class:`.Session` is bound.
@@ -1116,6 +1117,7 @@ class Session(_SessionClassMethods):
bound :class:`.MetaData`.
"""
+
if mapper is clause is None:
if self.bind:
return self.bind
@@ -1125,15 +1127,23 @@ class Session(_SessionClassMethods):
"Connection, and no context was provided to locate "
"a binding.")
- c_mapper = mapper is not None and _class_to_mapper(mapper) or None
+ if mapper is not None:
+ try:
+ mapper = inspect(mapper)
+ except sa_exc.NoInspectionAvailable:
+ if isinstance(mapper, type):
+ raise exc.UnmappedClassError(mapper)
+ else:
+ raise
- # manually bound?
if self.__binds:
- if c_mapper:
- if c_mapper.base_mapper in self.__binds:
- return self.__binds[c_mapper.base_mapper]
- elif c_mapper.mapped_table in self.__binds:
- return self.__binds[c_mapper.mapped_table]
+ if mapper:
+ for cls in mapper.class_.__mro__:
+ if cls in self.__binds:
+ return self.__binds[cls]
+ if clause is None:
+ clause = mapper.mapped_table
+
if clause is not None:
for t in sql_util.find_tables(clause, include_crud=True):
if t in self.__binds:
@@ -1145,12 +1155,12 @@ class Session(_SessionClassMethods):
if isinstance(clause, sql.expression.ClauseElement) and clause.bind:
return clause.bind
- if c_mapper and c_mapper.mapped_table.bind:
- return c_mapper.mapped_table.bind
+ if mapper and mapper.mapped_table.bind:
+ return mapper.mapped_table.bind
context = []
if mapper is not None:
- context.append('mapper %s' % c_mapper)
+ context.append('mapper %s' % mapper)
if clause is not None:
context.append('SQL expression')
@@ -1402,6 +1412,7 @@ class Session(_SessionClassMethods):
state._detach()
elif self.transaction:
self.transaction._deleted.pop(state, None)
+ state._detach()
def _register_newly_persistent(self, states):
for state in states:
@@ -2478,16 +2489,19 @@ def make_transient_to_detached(instance):
def object_session(instance):
- """Return the ``Session`` to which instance belongs.
+ """Return the :class:`.Session` to which the given instance belongs.
- If the instance is not a mapped instance, an error is raised.
+ This is essentially the same as the :attr:`.InstanceState.session`
+ accessor. See that attribute for details.
"""
try:
- return _state_session(attributes.instance_state(instance))
+ state = attributes.instance_state(instance)
except exc.NO_STATE:
raise exc.UnmappedInstanceError(instance)
+ else:
+ return _state_session(state)
_new_sessionid = util.counter()
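
A hedged sketch of the reworked bind handling shown earlier in this file; the
class, table and engines are illustrative. Classes and :class:`.Table` objects
may both be used as keys, and :meth:`.Session.get_bind` now walks the class
``__mro__``, so a bind established for a base class covers its subclasses::

    from sqlalchemy import Column, Integer, MetaData, Table, create_engine
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import Session

    Base = declarative_base()

    class Parent(Base):
        __tablename__ = 'parent'
        id = Column(Integer, primary_key=True)

    log_table = Table(
        'log', MetaData(), Column('id', Integer, primary_key=True))

    e1 = create_engine('sqlite://')
    e2 = create_engine('sqlite://')

    # classes and Table objects are both accepted as bind targets
    session = Session(binds={Parent: e1, log_table: e2})

    assert session.get_bind(Parent) is e1
    assert session.get_bind(clause=log_table.select()) is e2
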
diff --git a/lib/sqlalchemy/orm/state.py b/lib/sqlalchemy/orm/state.py
index 3c12fda1a..560149de5 100644
--- a/lib/sqlalchemy/orm/state.py
+++ b/lib/sqlalchemy/orm/state.py
@@ -145,7 +145,16 @@ class InstanceState(interfaces.InspectionAttr):
@util.dependencies("sqlalchemy.orm.session")
def session(self, sessionlib):
"""Return the owning :class:`.Session` for this instance,
- or ``None`` if none available."""
+ or ``None`` if none available.
+
+ Note that the result here can in some cases be *different*
+ from that of ``obj in session``; an object that's been deleted
+        will report as not ``in session``; however, if the transaction is
+ still in progress, this attribute will still refer to that session.
+ Only when the transaction is completed does the object become
+ fully detached under normal circumstances.
+
+ """
return sessionlib._state_session(self)
@property
@@ -258,8 +267,8 @@ class InstanceState(interfaces.InspectionAttr):
try:
return manager.original_init(*mixed[1:], **kwargs)
except:
- manager.dispatch.init_failure(self, args, kwargs)
- raise
+ with util.safe_reraise():
+ manager.dispatch.init_failure(self, args, kwargs)
def get_history(self, key, passive):
return self.manager[key].impl.get_history(self, self.dict, passive)
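
A hedged sketch of the distinction documented above; the mapping is
illustrative. A deleted object no longer reports as "in" the session after a
flush, yet :attr:`.InstanceState.session` keeps pointing at the owning session
until the transaction completes::

    from sqlalchemy import Column, Integer, create_engine, inspect
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import Session

    Base = declarative_base()

    class Thing(Base):
        __tablename__ = 'thing'
        id = Column(Integer, primary_key=True)

    engine = create_engine('sqlite://')
    Base.metadata.create_all(engine)
    session = Session(engine)

    obj = Thing()
    session.add(obj)
    session.flush()

    session.delete(obj)
    session.flush()

    assert obj not in session                 # deleted; not "in" the session
    assert inspect(obj).session is session    # still owned mid-transaction

    session.commit()
    assert inspect(obj).session is None       # detached once completed
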
diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py
index 734f9d5e6..ad610a4ac 100644
--- a/lib/sqlalchemy/orm/util.py
+++ b/lib/sqlalchemy/orm/util.py
@@ -30,13 +30,10 @@ class CascadeOptions(frozenset):
'all', 'none', 'delete-orphan'])
_allowed_cascades = all_cascades
- def __new__(cls, arg):
- values = set([
- c for c
- in re.split('\s*,\s*', arg or "")
- if c
- ])
-
+ def __new__(cls, value_list):
+ if isinstance(value_list, str) or value_list is None:
+ return cls.from_string(value_list)
+ values = set(value_list)
if values.difference(cls._allowed_cascades):
raise sa_exc.ArgumentError(
"Invalid cascade option(s): %s" %
@@ -70,6 +67,14 @@ class CascadeOptions(frozenset):
",".join([x for x in sorted(self)])
)
+ @classmethod
+ def from_string(cls, arg):
+ values = [
+ c for c
+            in re.split(r'\s*,\s*', arg or "")
+ if c
+ ]
+ return cls(values)
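
A hedged illustration of the two construction paths for this internal helper:
``CascadeOptions`` now accepts an iterable directly, with the legacy
comma-separated string handled by ``from_string()``::

    from sqlalchemy.orm.util import CascadeOptions

    from_str = CascadeOptions.from_string("save-update, merge")
    from_seq = CascadeOptions(["save-update", "merge"])

    # both produce the same frozenset of option names
    assert from_str == from_seq
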
def _validator_events(
desc, key, validator, include_removes, include_backrefs):
@@ -804,6 +809,16 @@ class _ORMJoin(expression.Join):
expression.Join.__init__(self, left, right, onclause, isouter)
+ if not prop and getattr(right_info, 'mapper', None) \
+ and right_info.mapper.single:
+ # if single inheritance target and we are using a manual
+ # or implicit ON clause, augment it the same way we'd augment the
+ # WHERE.
+ single_crit = right_info.mapper._single_table_criterion
+ if right_info.is_aliased_class:
+ single_crit = right_info._adapter.traverse(single_crit)
+ self.onclause = self.onclause & single_crit
+
def join(self, right, onclause=None, isouter=False, join_to_left=None):
return _ORMJoin(self, right, onclause, isouter)
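
A hedged sketch of the single-inheritance adjustment; the mapping is
illustrative. With an explicit ON clause, the discriminator criteria for the
subclass is now AND-ed into the join condition rather than appearing only in
the WHERE clause::

    from sqlalchemy import Column, ForeignKey, Integer, String, create_engine
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import Session

    Base = declarative_base()

    class Company(Base):
        __tablename__ = 'company'
        id = Column(Integer, primary_key=True)

    class Employee(Base):
        __tablename__ = 'employee'
        id = Column(Integer, primary_key=True)
        company_id = Column(Integer, ForeignKey('company.id'))
        type = Column(String)
        __mapper_args__ = {
            'polymorphic_on': type,
            'polymorphic_identity': 'employee'}

    class Engineer(Employee):
        # single-table inheritance: no __tablename__ of its own
        __mapper_args__ = {'polymorphic_identity': 'engineer'}

    session = Session(create_engine('sqlite://'))

    # the ON clause now renders roughly:
    #   company.id = employee.company_id AND employee.type IN (:type_1)
    print(session.query(Company).join(
        Engineer, Company.id == Engineer.company_id))
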
diff --git a/lib/sqlalchemy/pool.py b/lib/sqlalchemy/pool.py
index bc9affe4a..a174df784 100644
--- a/lib/sqlalchemy/pool.py
+++ b/lib/sqlalchemy/pool.py
@@ -248,9 +248,7 @@ class Pool(log.Identified):
self.logger.debug("Closing connection %r", connection)
try:
self._dialect.do_close(connection)
- except (SystemExit, KeyboardInterrupt):
- raise
- except:
+ except Exception:
self.logger.error("Exception closing connection %r",
connection, exc_info=True)
@@ -441,8 +439,8 @@ class _ConnectionRecord(object):
try:
dbapi_connection = rec.get_connection()
except:
- rec.checkin()
- raise
+ with util.safe_reraise():
+ rec.checkin()
echo = pool._should_log_debug()
fairy = _ConnectionFairy(dbapi_connection, rec, echo)
rec.fairy_ref = weakref.ref(
@@ -569,12 +567,12 @@ def _finalize_fairy(connection, connection_record,
# Immediately close detached instances
if not connection_record:
pool._close_connection(connection)
- except Exception as e:
+ except BaseException as e:
pool.logger.error(
"Exception during reset or similar", exc_info=True)
if connection_record:
connection_record.invalidate(e=e)
- if isinstance(e, (SystemExit, KeyboardInterrupt)):
+ if not isinstance(e, Exception):
raise
if connection_record:
@@ -842,9 +840,7 @@ class SingletonThreadPool(Pool):
for conn in self._all_conns:
try:
conn.close()
- except (SystemExit, KeyboardInterrupt):
- raise
- except:
+ except Exception:
# pysqlite won't even let you close a conn from a thread
# that didn't create it
pass
@@ -962,8 +958,8 @@ class QueuePool(Pool):
try:
return self._create_connection()
except:
- self._dec_overflow()
- raise
+ with util.safe_reraise():
+ self._dec_overflow()
else:
return self._do_get()
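
The ``util.safe_reraise()`` context manager adopted throughout these hunks
runs its cleanup body and then re-raises the exception that was being handled,
so the cleanup cannot silently swallow it; a hedged standalone sketch::

    from sqlalchemy import util

    def acquire():
        raise RuntimeError("connect failed")   # hypothetical failure

    cleaned_up = []

    try:
        try:
            acquire()
        except Exception:
            with util.safe_reraise():
                cleaned_up.append(True)        # cleanup runs first ...
    except RuntimeError:
        pass                                   # ... then the original re-raises

    assert cleaned_up == [True]
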
diff --git a/lib/sqlalchemy/sql/__init__.py b/lib/sqlalchemy/sql/__init__.py
index 4d013859c..351e08d0b 100644
--- a/lib/sqlalchemy/sql/__init__.py
+++ b/lib/sqlalchemy/sql/__init__.py
@@ -38,6 +38,7 @@ from .expression import (
false,
False_,
func,
+ funcfilter,
insert,
intersect,
intersect_all,
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 5149fa4fe..5fa78ad0f 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -24,12 +24,10 @@ To generate user-defined SQL strings, see
"""
import re
-from . import schema, sqltypes, operators, functions, \
- util as sql_util, visitors, elements, selectable, base
+from . import schema, sqltypes, operators, functions, visitors, \
+ elements, selectable, crud
from .. import util, exc
-import decimal
import itertools
-import operator
RESERVED_WORDS = set([
'all', 'analyse', 'analyze', 'and', 'any', 'array',
@@ -64,17 +62,6 @@ BIND_TEMPLATES = {
'named': ":%(name)s"
}
-REQUIRED = util.symbol('REQUIRED', """
-Placeholder for the value within a :class:`.BindParameter`
-which is required to be present when the statement is passed
-to :meth:`.Connection.execute`.
-
-This symbol is typically used when a :func:`.expression.insert`
-or :func:`.expression.update` statement is compiled without parameter
-values present.
-
-""")
-
OPERATORS = {
# binary
@@ -725,7 +712,6 @@ class SQLCompiler(Compiled):
for c in clauselist.clauses)
if s)
-
def visit_case(self, clause, **kwargs):
x = "CASE "
if clause.value is not None:
@@ -760,6 +746,12 @@ class SQLCompiler(Compiled):
)
)
+ def visit_funcfilter(self, funcfilter, **kwargs):
+ return "%s FILTER (WHERE %s)" % (
+ funcfilter.func._compiler_dispatch(self, **kwargs),
+ funcfilter.criterion._compiler_dispatch(self, **kwargs)
+ )
+
def visit_extract(self, extract, **kwargs):
field = self.extract_map.get(extract.field, extract.field)
return "EXTRACT(%s FROM %s)" % (
@@ -819,8 +811,9 @@ class SQLCompiler(Compiled):
text += " GROUP BY " + group_by
text += self.order_by_clause(cs, **kwargs)
- text += (cs._limit_clause is not None or cs._offset_clause is not None) and \
- self.limit_clause(cs) or ""
+ text += (cs._limit_clause is not None
+ or cs._offset_clause is not None) and \
+ self.limit_clause(cs, **kwargs) or ""
if self.ctes and \
compound_index == 0 and toplevel:
@@ -876,15 +869,15 @@ class SQLCompiler(Compiled):
isinstance(binary.right, elements.BindParameter):
kw['literal_binds'] = True
- operator = binary.operator
- disp = getattr(self, "visit_%s_binary" % operator.__name__, None)
+ operator_ = binary.operator
+ disp = getattr(self, "visit_%s_binary" % operator_.__name__, None)
if disp:
- return disp(binary, operator, **kw)
+ return disp(binary, operator_, **kw)
else:
try:
- opstring = OPERATORS[operator]
+ opstring = OPERATORS[operator_]
except KeyError:
- raise exc.UnsupportedCompilationError(self, operator)
+ raise exc.UnsupportedCompilationError(self, operator_)
else:
return self._generate_generic_binary(binary, opstring, **kw)
@@ -966,7 +959,7 @@ class SQLCompiler(Compiled):
' ESCAPE ' +
self.render_literal_value(escape, sqltypes.STRINGTYPE)
if escape else ''
- )
+ )
def visit_notlike_op_binary(self, binary, operator, **kw):
escape = binary.modifiers.get("escape", None)
@@ -977,7 +970,7 @@ class SQLCompiler(Compiled):
' ESCAPE ' +
self.render_literal_value(escape, sqltypes.STRINGTYPE)
if escape else ''
- )
+ )
def visit_ilike_op_binary(self, binary, operator, **kw):
escape = binary.modifiers.get("escape", None)
@@ -988,7 +981,7 @@ class SQLCompiler(Compiled):
' ESCAPE ' +
self.render_literal_value(escape, sqltypes.STRINGTYPE)
if escape else ''
- )
+ )
def visit_notilike_op_binary(self, binary, operator, **kw):
escape = binary.modifiers.get("escape", None)
@@ -999,7 +992,7 @@ class SQLCompiler(Compiled):
' ESCAPE ' +
self.render_literal_value(escape, sqltypes.STRINGTYPE)
if escape else ''
- )
+ )
def visit_between_op_binary(self, binary, operator, **kw):
symmetric = binary.modifiers.get("symmetric", False)
@@ -1331,6 +1324,9 @@ class SQLCompiler(Compiled):
def get_crud_hint_text(self, table, text):
return None
+ def get_statement_hint_text(self, hint_texts):
+ return " ".join(hint_texts)
+
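
A hedged sketch of how the new hook is reached from the public API; this
assumes the ``Select.with_statement_hint()`` method introduced alongside
these compiler changes in the 1.0 series::

    from sqlalchemy import column, select, table

    t = table('t', column('x'))

    # the hint text is appended verbatim after all other clauses, but
    # only when the compiling dialect matches ('*' matches all dialects)
    stmt = select([t.c.x]).with_statement_hint(
        "OPTION (MAXDOP 1)", dialect_name="mssql")
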
def _transform_select_for_nested_joins(self, select):
"""Rewrite any "a JOIN (b JOIN c)" expression as
"a JOIN (select * from b JOIN c) AS anon", to support
@@ -1501,29 +1497,7 @@ class SQLCompiler(Compiled):
select, transformed_select)
return text
- correlate_froms = entry['correlate_froms']
- asfrom_froms = entry['asfrom_froms']
-
- if asfrom:
- froms = select._get_display_froms(
- explicit_correlate_froms=correlate_froms.difference(
- asfrom_froms),
- implicit_correlate_froms=())
- else:
- froms = select._get_display_froms(
- explicit_correlate_froms=correlate_froms,
- implicit_correlate_froms=asfrom_froms)
-
- new_correlate_froms = set(selectable._from_objects(*froms))
- all_correlate_froms = new_correlate_froms.union(correlate_froms)
-
- new_entry = {
- 'asfrom_froms': new_correlate_froms,
- 'iswrapper': iswrapper,
- 'correlate_froms': all_correlate_froms,
- 'selectable': select,
- }
- self.stack.append(new_entry)
+ froms = self._setup_select_stack(select, entry, asfrom, iswrapper)
column_clause_args = kwargs.copy()
column_clause_args.update({
@@ -1534,18 +1508,11 @@ class SQLCompiler(Compiled):
text = "SELECT " # we're off to a good start !
if select._hints:
- byfrom = dict([
- (from_, hinttext % {
- 'name': from_._compiler_dispatch(
- self, ashint=True)
- })
- for (from_, dialect), hinttext in
- select._hints.items()
- if dialect in ('*', self.dialect.name)
- ])
- hint_text = self.get_select_hint_text(byfrom)
+ hint_text, byfrom = self._setup_select_hints(select)
if hint_text:
text += hint_text + " "
+ else:
+ byfrom = None
if select._prefixes:
text += self._generate_prefixes(
@@ -1566,6 +1533,70 @@ class SQLCompiler(Compiled):
if c is not None
]
+ text = self._compose_select_body(
+ text, select, inner_columns, froms, byfrom, kwargs)
+
+ if select._statement_hints:
+ per_dialect = [
+ ht for (dialect_name, ht)
+ in select._statement_hints
+ if dialect_name in ('*', self.dialect.name)
+ ]
+ if per_dialect:
+ text += " " + self.get_statement_hint_text(per_dialect)
+
+ if self.ctes and \
+ compound_index == 0 and toplevel:
+ text = self._render_cte_clause() + text
+
+ self.stack.pop(-1)
+
+ if asfrom and parens:
+ return "(" + text + ")"
+ else:
+ return text
+
+ def _setup_select_hints(self, select):
+ byfrom = dict([
+ (from_, hinttext % {
+ 'name': from_._compiler_dispatch(
+ self, ashint=True)
+ })
+ for (from_, dialect), hinttext in
+ select._hints.items()
+ if dialect in ('*', self.dialect.name)
+ ])
+ hint_text = self.get_select_hint_text(byfrom)
+ return hint_text, byfrom
+
+ def _setup_select_stack(self, select, entry, asfrom, iswrapper):
+ correlate_froms = entry['correlate_froms']
+ asfrom_froms = entry['asfrom_froms']
+
+ if asfrom:
+ froms = select._get_display_froms(
+ explicit_correlate_froms=correlate_froms.difference(
+ asfrom_froms),
+ implicit_correlate_froms=())
+ else:
+ froms = select._get_display_froms(
+ explicit_correlate_froms=correlate_froms,
+ implicit_correlate_froms=asfrom_froms)
+
+ new_correlate_froms = set(selectable._from_objects(*froms))
+ all_correlate_froms = new_correlate_froms.union(correlate_froms)
+
+ new_entry = {
+ 'asfrom_froms': new_correlate_froms,
+ 'iswrapper': iswrapper,
+ 'correlate_froms': all_correlate_froms,
+ 'selectable': select,
+ }
+ self.stack.append(new_entry)
+ return froms
+
+ def _compose_select_body(
+ self, text, select, inner_columns, froms, byfrom, kwargs):
text += ', '.join(inner_columns)
if froms:
@@ -1609,16 +1640,7 @@ class SQLCompiler(Compiled):
if select._for_update_arg is not None:
text += self.for_update_clause(select, **kwargs)
- if self.ctes and \
- compound_index == 0 and toplevel:
- text = self._render_cte_clause() + text
-
- self.stack.pop(-1)
-
- if asfrom and parens:
- return "(" + text + ")"
- else:
- return text
+ return text
def _generate_prefixes(self, stmt, prefixes, **kw):
clause = " ".join(
@@ -1708,9 +1730,9 @@ class SQLCompiler(Compiled):
def visit_insert(self, insert_stmt, **kw):
self.isinsert = True
- colparams = self._get_colparams(insert_stmt, **kw)
+ crud_params = crud._get_crud_params(self, insert_stmt, **kw)
- if not colparams and \
+ if not crud_params and \
not self.dialect.supports_default_values and \
not self.dialect.supports_empty_insert:
raise exc.CompileError("The '%s' dialect with current database "
@@ -1725,9 +1747,9 @@ class SQLCompiler(Compiled):
"version settings does not support "
"in-place multirow inserts." %
self.dialect.name)
- colparams_single = colparams[0]
+ crud_params_single = crud_params[0]
else:
- colparams_single = colparams
+ crud_params_single = crud_params
preparer = self.preparer
supports_default_values = self.dialect.supports_default_values
@@ -1758,9 +1780,9 @@ class SQLCompiler(Compiled):
text += table_text
- if colparams_single or not supports_default_values:
+ if crud_params_single or not supports_default_values:
text += " (%s)" % ', '.join([preparer.format_column(c[0])
- for c in colparams_single])
+ for c in crud_params_single])
if self.returning or insert_stmt._returning:
self.returning = self.returning or insert_stmt._returning
@@ -1771,21 +1793,21 @@ class SQLCompiler(Compiled):
text += " " + returning_clause
if insert_stmt.select is not None:
- text += " %s" % self.process(insert_stmt.select, **kw)
- elif not colparams and supports_default_values:
+ text += " %s" % self.process(self._insert_from_select, **kw)
+ elif not crud_params and supports_default_values:
text += " DEFAULT VALUES"
elif insert_stmt._has_multi_parameters:
text += " VALUES %s" % (
", ".join(
"(%s)" % (
- ', '.join(c[1] for c in colparam_set)
+ ', '.join(c[1] for c in crud_param_set)
)
- for colparam_set in colparams
+ for crud_param_set in crud_params
)
)
else:
text += " VALUES (%s)" % \
- ', '.join([c[1] for c in colparams])
+ ', '.join([c[1] for c in crud_params])
if self.returning and not self.returning_precedes_values:
text += " " + returning_clause
@@ -1842,7 +1864,7 @@ class SQLCompiler(Compiled):
table_text = self.update_tables_clause(update_stmt, update_stmt.table,
extra_froms, **kw)
- colparams = self._get_colparams(update_stmt, **kw)
+ crud_params = crud._get_crud_params(self, update_stmt, **kw)
if update_stmt._hints:
dialect_hints = dict([
@@ -1869,7 +1891,7 @@ class SQLCompiler(Compiled):
text += ', '.join(
c[0]._compiler_dispatch(self,
include_table=include_table) +
- '=' + c[1] for c in colparams
+ '=' + c[1] for c in crud_params
)
if self.returning or update_stmt._returning:
@@ -1905,380 +1927,9 @@ class SQLCompiler(Compiled):
return text
- def _create_crud_bind_param(self, col, value, required=False, name=None):
- if name is None:
- name = col.key
- bindparam = elements.BindParameter(name, value,
- type_=col.type, required=required)
- bindparam._is_crud = True
- return bindparam._compiler_dispatch(self)
-
@util.memoized_property
def _key_getters_for_crud_column(self):
- if self.isupdate and self.statement._extra_froms:
- # when extra tables are present, refer to the columns
- # in those extra tables as table-qualified, including in
- # dictionaries and when rendering bind param names.
- # the "main" table of the statement remains unqualified,
- # allowing the most compatibility with a non-multi-table
- # statement.
- _et = set(self.statement._extra_froms)
-
- def _column_as_key(key):
- str_key = elements._column_as_key(key)
- if hasattr(key, 'table') and key.table in _et:
- return (key.table.name, str_key)
- else:
- return str_key
-
- def _getattr_col_key(col):
- if col.table in _et:
- return (col.table.name, col.key)
- else:
- return col.key
-
- def _col_bind_name(col):
- if col.table in _et:
- return "%s_%s" % (col.table.name, col.key)
- else:
- return col.key
-
- else:
- _column_as_key = elements._column_as_key
- _getattr_col_key = _col_bind_name = operator.attrgetter("key")
-
- return _column_as_key, _getattr_col_key, _col_bind_name
-
- def _get_colparams(self, stmt, **kw):
- """create a set of tuples representing column/string pairs for use
- in an INSERT or UPDATE statement.
-
- Also generates the Compiled object's postfetch, prefetch, and
- returning column collections, used for default handling and ultimately
- populating the ResultProxy's prefetch_cols() and postfetch_cols()
- collections.
-
- """
-
- self.postfetch = []
- self.prefetch = []
- self.returning = []
-
- # no parameters in the statement, no parameters in the
- # compiled params - return binds for all columns
- if self.column_keys is None and stmt.parameters is None:
- return [
- (c, self._create_crud_bind_param(c,
- None, required=True))
- for c in stmt.table.columns
- ]
-
- if stmt._has_multi_parameters:
- stmt_parameters = stmt.parameters[0]
- else:
- stmt_parameters = stmt.parameters
-
- # getters - these are normally just column.key,
- # but in the case of mysql multi-table update, the rules for
- # .key must conditionally take tablename into account
- _column_as_key, _getattr_col_key, _col_bind_name = \
- self._key_getters_for_crud_column
-
- # if we have statement parameters - set defaults in the
- # compiled params
- if self.column_keys is None:
- parameters = {}
- else:
- parameters = dict((_column_as_key(key), REQUIRED)
- for key in self.column_keys
- if not stmt_parameters or
- key not in stmt_parameters)
-
- # create a list of column assignment clauses as tuples
- values = []
-
- if stmt_parameters is not None:
- for k, v in stmt_parameters.items():
- colkey = _column_as_key(k)
- if colkey is not None:
- parameters.setdefault(colkey, v)
- else:
- # a non-Column expression on the left side;
- # add it to values() in an "as-is" state,
- # coercing right side to bound param
- if elements._is_literal(v):
- v = self.process(
- elements.BindParameter(None, v, type_=k.type),
- **kw)
- else:
- v = self.process(v.self_group(), **kw)
-
- values.append((k, v))
-
- need_pks = self.isinsert and \
- not self.inline and \
- not stmt._returning and \
- not stmt._has_multi_parameters
-
- implicit_returning = need_pks and \
- self.dialect.implicit_returning and \
- stmt.table.implicit_returning
-
- if self.isinsert:
- implicit_return_defaults = (implicit_returning and
- stmt._return_defaults)
- elif self.isupdate:
- implicit_return_defaults = (self.dialect.implicit_returning and
- stmt.table.implicit_returning and
- stmt._return_defaults)
- else:
- implicit_return_defaults = False
-
- if implicit_return_defaults:
- if stmt._return_defaults is True:
- implicit_return_defaults = set(stmt.table.c)
- else:
- implicit_return_defaults = set(stmt._return_defaults)
-
- postfetch_lastrowid = need_pks and self.dialect.postfetch_lastrowid
-
- check_columns = {}
-
- # special logic that only occurs for multi-table UPDATE
- # statements
- if self.isupdate and stmt._extra_froms and stmt_parameters:
- normalized_params = dict(
- (elements._clause_element_as_expr(c), param)
- for c, param in stmt_parameters.items()
- )
- affected_tables = set()
- for t in stmt._extra_froms:
- for c in t.c:
- if c in normalized_params:
- affected_tables.add(t)
- check_columns[_getattr_col_key(c)] = c
- value = normalized_params[c]
- if elements._is_literal(value):
- value = self._create_crud_bind_param(
- c, value, required=value is REQUIRED,
- name=_col_bind_name(c))
- else:
- self.postfetch.append(c)
- value = self.process(value.self_group(), **kw)
- values.append((c, value))
- # determine tables which are actually
- # to be updated - process onupdate and
- # server_onupdate for these
- for t in affected_tables:
- for c in t.c:
- if c in normalized_params:
- continue
- elif (c.onupdate is not None and not
- c.onupdate.is_sequence):
- if c.onupdate.is_clause_element:
- values.append(
- (c, self.process(
- c.onupdate.arg.self_group(),
- **kw)
- )
- )
- self.postfetch.append(c)
- else:
- values.append(
- (c, self._create_crud_bind_param(
- c, None, name=_col_bind_name(c)
- )
- )
- )
- self.prefetch.append(c)
- elif c.server_onupdate is not None:
- self.postfetch.append(c)
-
- if self.isinsert and stmt.select_names:
- # for an insert from select, we can only use names that
- # are given, so only select for those names.
- cols = (stmt.table.c[_column_as_key(name)]
- for name in stmt.select_names)
- else:
- # iterate through all table columns to maintain
- # ordering, even for those cols that aren't included
- cols = stmt.table.columns
-
- for c in cols:
- col_key = _getattr_col_key(c)
- if col_key in parameters and col_key not in check_columns:
- value = parameters.pop(col_key)
- if elements._is_literal(value):
- value = self._create_crud_bind_param(
- c, value, required=value is REQUIRED,
- name=_col_bind_name(c)
- if not stmt._has_multi_parameters
- else "%s_0" % _col_bind_name(c)
- )
- else:
- if isinstance(value, elements.BindParameter) and \
- value.type._isnull:
- value = value._clone()
- value.type = c.type
-
- if c.primary_key and implicit_returning:
- self.returning.append(c)
- value = self.process(value.self_group(), **kw)
- elif implicit_return_defaults and \
- c in implicit_return_defaults:
- self.returning.append(c)
- value = self.process(value.self_group(), **kw)
- else:
- self.postfetch.append(c)
- value = self.process(value.self_group(), **kw)
- values.append((c, value))
-
- elif self.isinsert:
- if c.primary_key and \
- need_pks and \
- (
- implicit_returning or
- not postfetch_lastrowid or
- c is not stmt.table._autoincrement_column
- ):
-
- if implicit_returning:
- if c.default is not None:
- if c.default.is_sequence:
- if self.dialect.supports_sequences and \
- (not c.default.optional or
- not self.dialect.sequences_optional):
- proc = self.process(c.default, **kw)
- values.append((c, proc))
- self.returning.append(c)
- elif c.default.is_clause_element:
- values.append(
- (c, self.process(
- c.default.arg.self_group(), **kw))
- )
- self.returning.append(c)
- else:
- values.append(
- (c, self._create_crud_bind_param(c, None))
- )
- self.prefetch.append(c)
- else:
- self.returning.append(c)
- else:
- if (
- (c.default is not None and
- (not c.default.is_sequence or
- self.dialect.supports_sequences)) or
- c is stmt.table._autoincrement_column and
- (self.dialect.supports_sequences or
- self.dialect.
- preexecute_autoincrement_sequences)
- ):
-
- values.append(
- (c, self._create_crud_bind_param(c, None))
- )
-
- self.prefetch.append(c)
-
- elif c.default is not None:
- if c.default.is_sequence:
- if self.dialect.supports_sequences and \
- (not c.default.optional or
- not self.dialect.sequences_optional):
- proc = self.process(c.default, **kw)
- values.append((c, proc))
- if implicit_return_defaults and \
- c in implicit_return_defaults:
- self.returning.append(c)
- elif not c.primary_key:
- self.postfetch.append(c)
- elif c.default.is_clause_element:
- values.append(
- (c, self.process(
- c.default.arg.self_group(), **kw))
- )
-
- if implicit_return_defaults and \
- c in implicit_return_defaults:
- self.returning.append(c)
- elif not c.primary_key:
- # don't add primary key column to postfetch
- self.postfetch.append(c)
- else:
- values.append(
- (c, self._create_crud_bind_param(c, None))
- )
- self.prefetch.append(c)
- elif c.server_default is not None:
- if implicit_return_defaults and \
- c in implicit_return_defaults:
- self.returning.append(c)
- elif not c.primary_key:
- self.postfetch.append(c)
- elif implicit_return_defaults and \
- c in implicit_return_defaults:
- self.returning.append(c)
-
- elif self.isupdate:
- if c.onupdate is not None and not c.onupdate.is_sequence:
- if c.onupdate.is_clause_element:
- values.append(
- (c, self.process(
- c.onupdate.arg.self_group(), **kw))
- )
- if implicit_return_defaults and \
- c in implicit_return_defaults:
- self.returning.append(c)
- else:
- self.postfetch.append(c)
- else:
- values.append(
- (c, self._create_crud_bind_param(c, None))
- )
- self.prefetch.append(c)
- elif c.server_onupdate is not None:
- if implicit_return_defaults and \
- c in implicit_return_defaults:
- self.returning.append(c)
- else:
- self.postfetch.append(c)
- elif implicit_return_defaults and \
- c in implicit_return_defaults:
- self.returning.append(c)
-
- if parameters and stmt_parameters:
- check = set(parameters).intersection(
- _column_as_key(k) for k in stmt.parameters
- ).difference(check_columns)
- if check:
- raise exc.CompileError(
- "Unconsumed column names: %s" %
- (", ".join("%s" % c for c in check))
- )
-
- if stmt._has_multi_parameters:
- values_0 = values
- values = [values]
-
- values.extend(
- [
- (
- c,
- (self._create_crud_bind_param(
- c, row[c.key],
- name="%s_%d" % (c.key, i + 1)
- ) if elements._is_literal(row[c.key])
- else self.process(
- row[c.key].self_group(), **kw))
- if c.key in row else param
- )
- for (c, param) in values_0
- ]
- for i, row in enumerate(stmt.parameters[1:])
- )
-
- return values
+ return crud._key_getters_for_crud_column(self)
def visit_delete(self, delete_stmt, **kw):
self.stack.append({'correlate_froms': set([delete_stmt.table]),
@@ -2462,17 +2113,18 @@ class DDLCompiler(Compiled):
constraints.extend([c for c in table._sorted_constraints
if c is not table.primary_key])
- return ", \n\t".join(p for p in
- (self.process(constraint)
- for constraint in constraints
- if (
- constraint._create_rule is None or
- constraint._create_rule(self))
- and (
- not self.dialect.supports_alter or
- not getattr(constraint, 'use_alter', False)
- )) if p is not None
- )
+ return ", \n\t".join(
+ p for p in
+ (self.process(constraint)
+ for constraint in constraints
+ if (
+ constraint._create_rule is None or
+ constraint._create_rule(self))
+ and (
+ not self.dialect.supports_alter or
+ not getattr(constraint, 'use_alter', False)
+ )) if p is not None
+ )
def visit_drop_table(self, drop):
return "\nDROP TABLE " + self.preparer.format_table(drop.element)
diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py
new file mode 100644
index 000000000..831d05be1
--- /dev/null
+++ b/lib/sqlalchemy/sql/crud.py
@@ -0,0 +1,530 @@
+# sql/crud.py
+# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+
+"""Functions used by compiler.py to determine the parameters rendered
+within INSERT and UPDATE statements.
+
+"""
+from .. import util
+from .. import exc
+from . import elements
+import operator
+
+REQUIRED = util.symbol('REQUIRED', """
+Placeholder for the value within a :class:`.BindParameter`
+which is required to be present when the statement is passed
+to :meth:`.Connection.execute`.
+
+This symbol is typically used when a :func:`.expression.insert`
+or :func:`.expression.update` statement is compiled without parameter
+values present.
+
+""")
+
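
A hedged sketch of the REQUIRED placeholder in action: compiling an insert
with no values produces required binds, and parameter construction fails if
any of them is left unsatisfied::

    from sqlalchemy import column, exc, table

    t = table('t', column('x'), column('y'))
    compiled = t.insert().compile()       # all binds are REQUIRED

    compiled.construct_params({'x': 1, 'y': 2})    # satisfied

    try:
        compiled.construct_params({'x': 1})        # 'y' missing
    except exc.InvalidRequestError:
        pass    # "A value is required for bind parameter 'y'"
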
+
+def _get_crud_params(compiler, stmt, **kw):
+ """create a set of tuples representing column/string pairs for use
+ in an INSERT or UPDATE statement.
+
+ Also generates the Compiled object's postfetch, prefetch, and
+ returning column collections, used for default handling and ultimately
+ populating the ResultProxy's prefetch_cols() and postfetch_cols()
+ collections.
+
+ """
+
+ compiler.postfetch = []
+ compiler.prefetch = []
+ compiler.returning = []
+
+ # no parameters in the statement, no parameters in the
+ # compiled params - return binds for all columns
+ if compiler.column_keys is None and stmt.parameters is None:
+ return [
+ (c, _create_bind_param(
+ compiler, c, None, required=True))
+ for c in stmt.table.columns
+ ]
+
+ if stmt._has_multi_parameters:
+ stmt_parameters = stmt.parameters[0]
+ else:
+ stmt_parameters = stmt.parameters
+
+ # getters - these are normally just column.key,
+ # but in the case of mysql multi-table update, the rules for
+ # .key must conditionally take tablename into account
+ _column_as_key, _getattr_col_key, _col_bind_name = \
+ _key_getters_for_crud_column(compiler)
+
+ # if we have statement parameters - set defaults in the
+ # compiled params
+ if compiler.column_keys is None:
+ parameters = {}
+ else:
+ parameters = dict((_column_as_key(key), REQUIRED)
+ for key in compiler.column_keys
+ if not stmt_parameters or
+ key not in stmt_parameters)
+
+ # create a list of column assignment clauses as tuples
+ values = []
+
+ if stmt_parameters is not None:
+ _get_stmt_parameters_params(
+ compiler,
+ parameters, stmt_parameters, _column_as_key, values, kw)
+
+ check_columns = {}
+
+ # special logic that only occurs for multi-table UPDATE
+ # statements
+ if compiler.isupdate and stmt._extra_froms and stmt_parameters:
+ _get_multitable_params(
+ compiler, stmt, stmt_parameters, check_columns,
+ _col_bind_name, _getattr_col_key, values, kw)
+
+ if compiler.isinsert and stmt.select_names:
+ _scan_insert_from_select_cols(
+ compiler, stmt, parameters,
+ _getattr_col_key, _column_as_key,
+ _col_bind_name, check_columns, values, kw)
+ else:
+ _scan_cols(
+ compiler, stmt, parameters,
+ _getattr_col_key, _column_as_key,
+ _col_bind_name, check_columns, values, kw)
+
+ if parameters and stmt_parameters:
+ check = set(parameters).intersection(
+ _column_as_key(k) for k in stmt.parameters
+ ).difference(check_columns)
+ if check:
+ raise exc.CompileError(
+ "Unconsumed column names: %s" %
+ (", ".join("%s" % c for c in check))
+ )
+
+ if stmt._has_multi_parameters:
+ values = _extend_values_for_multiparams(compiler, stmt, values, kw)
+
+ return values
+
+
+def _create_bind_param(
+ compiler, col, value, process=True, required=False, name=None):
+ if name is None:
+ name = col.key
+ bindparam = elements.BindParameter(name, value,
+ type_=col.type, required=required)
+ bindparam._is_crud = True
+ if process:
+ bindparam = bindparam._compiler_dispatch(compiler)
+ return bindparam
+
+
+def _key_getters_for_crud_column(compiler):
+ if compiler.isupdate and compiler.statement._extra_froms:
+ # when extra tables are present, refer to the columns
+ # in those extra tables as table-qualified, including in
+ # dictionaries and when rendering bind param names.
+ # the "main" table of the statement remains unqualified,
+ # allowing the most compatibility with a non-multi-table
+ # statement.
+ _et = set(compiler.statement._extra_froms)
+
+ def _column_as_key(key):
+ str_key = elements._column_as_key(key)
+ if hasattr(key, 'table') and key.table in _et:
+ return (key.table.name, str_key)
+ else:
+ return str_key
+
+ def _getattr_col_key(col):
+ if col.table in _et:
+ return (col.table.name, col.key)
+ else:
+ return col.key
+
+ def _col_bind_name(col):
+ if col.table in _et:
+ return "%s_%s" % (col.table.name, col.key)
+ else:
+ return col.key
+
+ else:
+ _column_as_key = elements._column_as_key
+ _getattr_col_key = _col_bind_name = operator.attrgetter("key")
+
+ return _column_as_key, _getattr_col_key, _col_bind_name
+
+
+def _scan_insert_from_select_cols(
+ compiler, stmt, parameters, _getattr_col_key,
+ _column_as_key, _col_bind_name, check_columns, values, kw):
+
+ need_pks, implicit_returning, \
+ implicit_return_defaults, postfetch_lastrowid = \
+ _get_returning_modifiers(compiler, stmt)
+
+ cols = [stmt.table.c[_column_as_key(name)]
+ for name in stmt.select_names]
+
+ compiler._insert_from_select = stmt.select
+
+ add_select_cols = []
+ if stmt.include_insert_from_select_defaults:
+ col_set = set(cols)
+ for col in stmt.table.columns:
+ if col not in col_set and col.default:
+ cols.append(col)
+
+ for c in cols:
+ col_key = _getattr_col_key(c)
+ if col_key in parameters and col_key not in check_columns:
+ parameters.pop(col_key)
+ values.append((c, None))
+ else:
+ _append_param_insert_select_hasdefault(
+ compiler, stmt, c, add_select_cols, kw)
+
+ if add_select_cols:
+ values.extend(add_select_cols)
+ compiler._insert_from_select = compiler._insert_from_select._generate()
+ compiler._insert_from_select._raw_columns += tuple(
+ expr for col, expr in add_select_cols)
+
+
+def _scan_cols(
+ compiler, stmt, parameters, _getattr_col_key,
+ _column_as_key, _col_bind_name, check_columns, values, kw):
+
+ need_pks, implicit_returning, \
+ implicit_return_defaults, postfetch_lastrowid = \
+ _get_returning_modifiers(compiler, stmt)
+
+ cols = stmt.table.columns
+
+ for c in cols:
+ col_key = _getattr_col_key(c)
+ if col_key in parameters and col_key not in check_columns:
+
+ _append_param_parameter(
+ compiler, stmt, c, col_key, parameters, _col_bind_name,
+ implicit_returning, implicit_return_defaults, values, kw)
+
+ elif compiler.isinsert:
+ if c.primary_key and \
+ need_pks and \
+ (
+ implicit_returning or
+ not postfetch_lastrowid or
+ c is not stmt.table._autoincrement_column
+ ):
+
+ if implicit_returning:
+ _append_param_insert_pk_returning(
+ compiler, stmt, c, values, kw)
+ else:
+ _append_param_insert_pk(compiler, stmt, c, values, kw)
+
+ elif c.default is not None:
+
+ _append_param_insert_hasdefault(
+ compiler, stmt, c, implicit_return_defaults,
+ values, kw)
+
+ elif c.server_default is not None:
+ if implicit_return_defaults and \
+ c in implicit_return_defaults:
+ compiler.returning.append(c)
+ elif not c.primary_key:
+ compiler.postfetch.append(c)
+ elif implicit_return_defaults and \
+ c in implicit_return_defaults:
+ compiler.returning.append(c)
+
+ elif compiler.isupdate:
+ _append_param_update(
+ compiler, stmt, c, implicit_return_defaults, values, kw)
+
+
+def _append_param_parameter(
+ compiler, stmt, c, col_key, parameters, _col_bind_name,
+ implicit_returning, implicit_return_defaults, values, kw):
+ value = parameters.pop(col_key)
+ if elements._is_literal(value):
+ value = _create_bind_param(
+ compiler, c, value, required=value is REQUIRED,
+ name=_col_bind_name(c)
+ if not stmt._has_multi_parameters
+ else "%s_0" % _col_bind_name(c)
+ )
+ else:
+ if isinstance(value, elements.BindParameter) and \
+ value.type._isnull:
+ value = value._clone()
+ value.type = c.type
+
+ if c.primary_key and implicit_returning:
+ compiler.returning.append(c)
+ value = compiler.process(value.self_group(), **kw)
+ elif implicit_return_defaults and \
+ c in implicit_return_defaults:
+ compiler.returning.append(c)
+ value = compiler.process(value.self_group(), **kw)
+ else:
+ compiler.postfetch.append(c)
+ value = compiler.process(value.self_group(), **kw)
+ values.append((c, value))
+
+
+def _append_param_insert_pk_returning(compiler, stmt, c, values, kw):
+ if c.default is not None:
+ if c.default.is_sequence:
+ if compiler.dialect.supports_sequences and \
+ (not c.default.optional or
+ not compiler.dialect.sequences_optional):
+ proc = compiler.process(c.default, **kw)
+ values.append((c, proc))
+ compiler.returning.append(c)
+ elif c.default.is_clause_element:
+ values.append(
+ (c, compiler.process(
+ c.default.arg.self_group(), **kw))
+ )
+ compiler.returning.append(c)
+ else:
+ values.append(
+ (c, _create_bind_param(compiler, c, None))
+ )
+ compiler.prefetch.append(c)
+ else:
+ compiler.returning.append(c)
+
+
+def _append_param_insert_pk(compiler, stmt, c, values, kw):
+ if (
+ (c.default is not None and
+ (not c.default.is_sequence or
+ compiler.dialect.supports_sequences)) or
+ c is stmt.table._autoincrement_column and
+ (compiler.dialect.supports_sequences or
+ compiler.dialect.
+ preexecute_autoincrement_sequences)
+ ):
+ values.append(
+ (c, _create_bind_param(compiler, c, None))
+ )
+
+ compiler.prefetch.append(c)
+
+
+def _append_param_insert_hasdefault(
+ compiler, stmt, c, implicit_return_defaults, values, kw):
+
+ if c.default.is_sequence:
+ if compiler.dialect.supports_sequences and \
+ (not c.default.optional or
+ not compiler.dialect.sequences_optional):
+ proc = compiler.process(c.default, **kw)
+ values.append((c, proc))
+ if implicit_return_defaults and \
+ c in implicit_return_defaults:
+ compiler.returning.append(c)
+ elif not c.primary_key:
+ compiler.postfetch.append(c)
+ elif c.default.is_clause_element:
+ proc = compiler.process(c.default.arg.self_group(), **kw)
+ values.append((c, proc))
+
+ if implicit_return_defaults and \
+ c in implicit_return_defaults:
+ compiler.returning.append(c)
+ elif not c.primary_key:
+ # don't add primary key column to postfetch
+ compiler.postfetch.append(c)
+ else:
+ values.append(
+ (c, _create_bind_param(compiler, c, None))
+ )
+ compiler.prefetch.append(c)
+
+
+def _append_param_insert_select_hasdefault(
+ compiler, stmt, c, values, kw):
+
+ if c.default.is_sequence:
+ if compiler.dialect.supports_sequences and \
+ (not c.default.optional or
+ not compiler.dialect.sequences_optional):
+ proc = c.default
+ values.append((c, proc))
+ elif c.default.is_clause_element:
+ proc = c.default.arg.self_group()
+ values.append((c, proc))
+ else:
+ values.append(
+ (c, _create_bind_param(compiler, c, None, process=False))
+ )
+ compiler.prefetch.append(c)
+
+
+def _append_param_update(
+ compiler, stmt, c, implicit_return_defaults, values, kw):
+
+ if c.onupdate is not None and not c.onupdate.is_sequence:
+ if c.onupdate.is_clause_element:
+ values.append(
+ (c, compiler.process(
+ c.onupdate.arg.self_group(), **kw))
+ )
+ if implicit_return_defaults and \
+ c in implicit_return_defaults:
+ compiler.returning.append(c)
+ else:
+ compiler.postfetch.append(c)
+ else:
+ values.append(
+ (c, _create_bind_param(compiler, c, None))
+ )
+ compiler.prefetch.append(c)
+ elif c.server_onupdate is not None:
+ if implicit_return_defaults and \
+ c in implicit_return_defaults:
+ compiler.returning.append(c)
+ else:
+ compiler.postfetch.append(c)
+ elif implicit_return_defaults and \
+ c in implicit_return_defaults:
+ compiler.returning.append(c)
+
+
+def _get_multitable_params(
+ compiler, stmt, stmt_parameters, check_columns,
+ _col_bind_name, _getattr_col_key, values, kw):
+
+ normalized_params = dict(
+ (elements._clause_element_as_expr(c), param)
+ for c, param in stmt_parameters.items()
+ )
+ affected_tables = set()
+ for t in stmt._extra_froms:
+ for c in t.c:
+ if c in normalized_params:
+ affected_tables.add(t)
+ check_columns[_getattr_col_key(c)] = c
+ value = normalized_params[c]
+ if elements._is_literal(value):
+ value = _create_bind_param(
+ compiler, c, value, required=value is REQUIRED,
+ name=_col_bind_name(c))
+ else:
+ compiler.postfetch.append(c)
+ value = compiler.process(value.self_group(), **kw)
+ values.append((c, value))
+ # determine tables which are actually to be updated - process onupdate
+ # and server_onupdate for these
+ for t in affected_tables:
+ for c in t.c:
+ if c in normalized_params:
+ continue
+ elif (c.onupdate is not None and not
+ c.onupdate.is_sequence):
+ if c.onupdate.is_clause_element:
+ values.append(
+ (c, compiler.process(
+ c.onupdate.arg.self_group(),
+ **kw)
+ )
+ )
+ compiler.postfetch.append(c)
+ else:
+ values.append(
+ (c, _create_bind_param(
+ compiler, c, None, name=_col_bind_name(c)
+ )
+ )
+ )
+ compiler.prefetch.append(c)
+ elif c.server_onupdate is not None:
+ compiler.postfetch.append(c)
+
+
+def _extend_values_for_multiparams(compiler, stmt, values, kw):
+ values_0 = values
+ values = [values]
+
+ values.extend(
+ [
+ (
+ c,
+ (_create_bind_param(
+ compiler, c, row[c.key],
+ name="%s_%d" % (c.key, i + 1)
+ ) if elements._is_literal(row[c.key])
+ else compiler.process(
+ row[c.key].self_group(), **kw))
+ if c.key in row else param
+ )
+ for (c, param) in values_0
+ ]
+ for i, row in enumerate(stmt.parameters[1:])
+ )
+ return values
+
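
A hedged sketch of the multi-parameter path above: a single insert().values()
call with a list of dictionaries renders one parenthesized group per row,
using the numbered bind names generated here::

    from sqlalchemy import column, table
    from sqlalchemy.dialects import postgresql

    t = table('t', column('x'))
    stmt = t.insert().values([{'x': 1}, {'x': 2}, {'x': 3}])

    # requires a dialect with multirow VALUES support; renders roughly:
    #   INSERT INTO t (x) VALUES (%(x_0)s), (%(x_1)s), (%(x_2)s)
    print(stmt.compile(dialect=postgresql.dialect()))
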
+
+def _get_stmt_parameters_params(
+ compiler, parameters, stmt_parameters, _column_as_key, values, kw):
+ for k, v in stmt_parameters.items():
+ colkey = _column_as_key(k)
+ if colkey is not None:
+ parameters.setdefault(colkey, v)
+ else:
+ # a non-Column expression on the left side;
+ # add it to values() in an "as-is" state,
+ # coercing right side to bound param
+ if elements._is_literal(v):
+ v = compiler.process(
+ elements.BindParameter(None, v, type_=k.type),
+ **kw)
+ else:
+ v = compiler.process(v.self_group(), **kw)
+
+ values.append((k, v))
+
+
+def _get_returning_modifiers(compiler, stmt):
+ need_pks = compiler.isinsert and \
+ not compiler.inline and \
+ not stmt._returning and \
+ not stmt._has_multi_parameters
+
+ implicit_returning = need_pks and \
+ compiler.dialect.implicit_returning and \
+ stmt.table.implicit_returning
+
+ if compiler.isinsert:
+ implicit_return_defaults = (implicit_returning and
+ stmt._return_defaults)
+ elif compiler.isupdate:
+ implicit_return_defaults = (compiler.dialect.implicit_returning and
+ stmt.table.implicit_returning and
+ stmt._return_defaults)
+ else:
+ implicit_return_defaults = False
+
+ if implicit_return_defaults:
+ if stmt._return_defaults is True:
+ implicit_return_defaults = set(stmt.table.c)
+ else:
+ implicit_return_defaults = set(stmt._return_defaults)
+
+ postfetch_lastrowid = need_pks and compiler.dialect.postfetch_lastrowid
+
+ return need_pks, implicit_returning, \
+ implicit_return_defaults, postfetch_lastrowid
diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py
index 1934d0776..9f2ce7ce3 100644
--- a/lib/sqlalchemy/sql/dml.py
+++ b/lib/sqlalchemy/sql/dml.py
@@ -475,6 +475,7 @@ class Insert(ValuesBase):
ValuesBase.__init__(self, table, values, prefixes)
self._bind = bind
self.select = self.select_names = None
+ self.include_insert_from_select_defaults = False
self.inline = inline
self._returning = returning
self._validate_dialect_kwargs(dialect_kw)
@@ -487,7 +488,7 @@ class Insert(ValuesBase):
return ()
@_generative
- def from_select(self, names, select):
+ def from_select(self, names, select, include_defaults=True):
"""Return a new :class:`.Insert` construct which represents
an ``INSERT...FROM SELECT`` statement.
@@ -506,6 +507,21 @@ class Insert(ValuesBase):
is not checked before passing along to the database, the database
would normally raise an exception if these column lists don't
correspond.
+ :param include_defaults: if True, non-server default values and
+ SQL expressions as specified on :class:`.Column` objects
+ (as documented in :ref:`metadata_defaults_toplevel`) not
+ otherwise specified in the list of names will be rendered
+ into the INSERT and SELECT statements, so that these values are also
+ included in the data to be inserted.
+
+        .. note:: A Python-side default that uses a Python callable
+           will be invoked only **once** for the whole statement,
+           **not once per row**.
+
+ .. versionadded:: 1.0.0 - :meth:`.Insert.from_select` now renders
+ Python-side and SQL expression column defaults into the
+ SELECT statement for columns otherwise not included in the
+ list of column names.
.. versionchanged:: 1.0.0 an INSERT that uses FROM SELECT
implies that the :paramref:`.insert.inline` flag is set to
@@ -514,13 +530,6 @@ class Insert(ValuesBase):
deals with an arbitrary number of rows, so the
:attr:`.ResultProxy.inserted_primary_key` accessor does not apply.
- .. note::
-
- A SELECT..INSERT construct in SQL has no VALUES clause. Therefore
- :class:`.Column` objects which utilize Python-side defaults
- (e.g. as described at :ref:`metadata_defaults_toplevel`)
- will **not** take effect when using :meth:`.Insert.from_select`.
-
.. versionadded:: 0.8.3
"""
@@ -533,6 +542,7 @@ class Insert(ValuesBase):
self.select_names = names
self.inline = True
+ self.include_insert_from_select_defaults = include_defaults
self.select = _interpret_as_select(select)
def _copy_internals(self, clone=_clone, **kw):
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py
index 8ec0aa700..fa9b66024 100644
--- a/lib/sqlalchemy/sql/elements.py
+++ b/lib/sqlalchemy/sql/elements.py
@@ -228,6 +228,7 @@ class ClauseElement(Visitable):
is_selectable = False
is_clause_element = True
+ description = None
_order_by_label_element = None
_is_from_container = False
@@ -540,7 +541,7 @@ class ClauseElement(Visitable):
__nonzero__ = __bool__
def __repr__(self):
- friendly = getattr(self, 'description', None)
+ friendly = self.description
if friendly is None:
return object.__repr__(self)
else:
@@ -860,6 +861,9 @@ class ColumnElement(operators.ColumnOperators, ClauseElement):
expressions and function calls.
"""
+ while self._is_clone_of is not None:
+ self = self._is_clone_of
+
return _anonymous_label(
'%%(%d %s)s' % (id(self), getattr(self, 'name', 'anon'))
)
@@ -1616,10 +1620,10 @@ class Null(ColumnElement):
return type_api.NULLTYPE
@classmethod
- def _singleton(cls):
+ def _instance(cls):
"""Return a constant :class:`.Null` construct."""
- return NULL
+ return Null()
def compare(self, other):
return isinstance(other, Null)
@@ -1640,11 +1644,11 @@ class False_(ColumnElement):
return type_api.BOOLEANTYPE
def _negate(self):
- return TRUE
+ return True_()
@classmethod
- def _singleton(cls):
- """Return a constant :class:`.False_` construct.
+ def _instance(cls):
+ """Return a :class:`.False_` construct.
E.g.::
@@ -1678,7 +1682,7 @@ class False_(ColumnElement):
"""
- return FALSE
+ return False_()
def compare(self, other):
return isinstance(other, False_)
@@ -1699,17 +1703,17 @@ class True_(ColumnElement):
return type_api.BOOLEANTYPE
def _negate(self):
- return FALSE
+ return False_()
@classmethod
def _ifnone(cls, other):
if other is None:
- return cls._singleton()
+ return cls._instance()
else:
return other
@classmethod
- def _singleton(cls):
+ def _instance(cls):
"""Return a constant :class:`.True_` construct.
E.g.::
@@ -1744,15 +1748,11 @@ class True_(ColumnElement):
"""
- return TRUE
+ return True_()
def compare(self, other):
return isinstance(other, True_)
-NULL = Null()
-FALSE = False_()
-TRUE = True_()
-
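
A hedged note on the singleton removal: ``null()``, ``true()`` and ``false()``
now return fresh instances on each call, while ``compare()`` continues to
treat any two instances of the same type as equivalent::

    from sqlalchemy import null, true

    assert null() is not null()     # no longer module-level singletons
    assert null().compare(null())   # still compare as equal
    assert true().compare(true())
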
class ClauseList(ClauseElement):
"""Describe a list of clauses, separated by an operator.
@@ -2782,6 +2782,10 @@ class Grouping(ColumnElement):
return self
@property
+ def _key_label(self):
+ return self._label
+
+ @property
def _label(self):
return getattr(self.element, '_label', None) or self.anon_label
@@ -2888,6 +2892,120 @@ class Over(ColumnElement):
))
+class FunctionFilter(ColumnElement):
+ """Represent a function FILTER clause.
+
+    This is a special operator against aggregate and window functions,
+    which controls which rows are passed to the function.
+    It is supported only by certain database backends.
+
+ Invocation of :class:`.FunctionFilter` is via
+ :meth:`.FunctionElement.filter`::
+
+ func.count(1).filter(True)
+
+ .. versionadded:: 1.0.0
+
+ .. seealso::
+
+ :meth:`.FunctionElement.filter`
+
+ """
+ __visit_name__ = 'funcfilter'
+
+ criterion = None
+
+ def __init__(self, func, *criterion):
+ """Produce a :class:`.FunctionFilter` object against a function.
+
+ Used against aggregate and window functions,
+ for database backends that support the "FILTER" clause.
+
+ E.g.::
+
+ from sqlalchemy import funcfilter
+ funcfilter(func.count(1), MyClass.name == 'some name')
+
+ Would produce "COUNT(1) FILTER (WHERE myclass.name = 'some name')".
+
+ This function is also available from the :data:`~.expression.func`
+ construct itself via the :meth:`.FunctionElement.filter` method.
+
+ .. versionadded:: 1.0.0
+
+ .. seealso::
+
+ :meth:`.FunctionElement.filter`
+
+
+ """
+ self.func = func
+ self.filter(*criterion)
+
+ def filter(self, *criterion):
+ """Produce an additional FILTER against the function.
+
+ This method adds additional criteria to the initial criteria
+ set up by :meth:`.FunctionElement.filter`.
+
+ Multiple criteria are joined together at SQL render time
+ via ``AND``.
+
+
+ """
+
+ for criterion in list(criterion):
+ criterion = _expression_literal_as_text(criterion)
+
+ if self.criterion is not None:
+ self.criterion = self.criterion & criterion
+ else:
+ self.criterion = criterion
+
+ return self
+
+ def over(self, partition_by=None, order_by=None):
+ """Produce an OVER clause against this filtered function.
+
+ Used against aggregate or so-called "window" functions,
+ for database backends that support window functions.
+
+ The expression::
+
+ func.rank().filter(MyClass.y > 5).over(order_by='x')
+
+ is shorthand for::
+
+ from sqlalchemy import over, funcfilter
+ over(funcfilter(func.rank(), MyClass.y > 5), order_by='x')
+
+ See :func:`~.expression.over` for a full description.
+
+ """
+ return Over(self, partition_by=partition_by, order_by=order_by)
+
+ @util.memoized_property
+ def type(self):
+ return self.func.type
+
+ def get_children(self, **kwargs):
+ return [c for c in
+ (self.func, self.criterion)
+ if c is not None]
+
+ def _copy_internals(self, clone=_clone, **kw):
+ self.func = clone(self.func, **kw)
+ if self.criterion is not None:
+ self.criterion = clone(self.criterion, **kw)
+
+ @property
+ def _from_objects(self):
+ return list(itertools.chain(
+ *[c._from_objects for c in (self.func, self.criterion)
+ if c is not None]
+ ))
+
+
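
A hedged end-to-end sketch of the construct, showing both the standalone
:func:`.funcfilter` form and the method form, and the SQL they render::

    from sqlalchemy import column, func, funcfilter, select

    unit = column('unit')
    value = column('value')

    stmt = select([
        funcfilter(func.count(1), unit == 'metric'),   # standalone form
        func.sum(value).filter(unit == 'metric'),      # method form
    ])

    # renders roughly:
    #   SELECT count(1) FILTER (WHERE unit = :unit_1) AS anon_1,
    #          sum(value) FILTER (WHERE unit = :unit_2) AS anon_2
    print(stmt)
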
class Label(ColumnElement):
"""Represents a column label (AS).
@@ -3491,7 +3609,7 @@ def _string_or_unprintable(element):
else:
try:
return str(element)
- except:
+ except Exception:
return "unprintable element %r" % element
diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py
index d96f048b9..2ffc5468c 100644
--- a/lib/sqlalchemy/sql/expression.py
+++ b/lib/sqlalchemy/sql/expression.py
@@ -36,7 +36,7 @@ from .elements import ClauseElement, ColumnElement,\
True_, False_, BinaryExpression, Tuple, TypeClause, Extract, \
Grouping, not_, \
collate, literal_column, between,\
- literal, outparam, type_coerce, ClauseList
+ literal, outparam, type_coerce, ClauseList, FunctionFilter
from .elements import SavepointClause, RollbackToSavepointClause, \
ReleaseSavepointClause
@@ -89,14 +89,16 @@ asc = public_factory(UnaryExpression._create_asc, ".expression.asc")
desc = public_factory(UnaryExpression._create_desc, ".expression.desc")
distinct = public_factory(
UnaryExpression._create_distinct, ".expression.distinct")
-true = public_factory(True_._singleton, ".expression.true")
-false = public_factory(False_._singleton, ".expression.false")
-null = public_factory(Null._singleton, ".expression.null")
+true = public_factory(True_._instance, ".expression.true")
+false = public_factory(False_._instance, ".expression.false")
+null = public_factory(Null._instance, ".expression.null")
join = public_factory(Join._create_join, ".expression.join")
outerjoin = public_factory(Join._create_outerjoin, ".expression.outerjoin")
insert = public_factory(Insert, ".expression.insert")
update = public_factory(Update, ".expression.update")
delete = public_factory(Delete, ".expression.delete")
+funcfilter = public_factory(
+ FunctionFilter, ".expression.funcfilter")
# internal functions still being called from tests and the ORM,
diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py
index 7efb1e916..9280c7d60 100644
--- a/lib/sqlalchemy/sql/functions.py
+++ b/lib/sqlalchemy/sql/functions.py
@@ -12,7 +12,7 @@ from . import sqltypes, schema
from .base import Executable, ColumnCollection
from .elements import ClauseList, Cast, Extract, _literal_as_binds, \
literal_column, _type_from_args, ColumnElement, _clone,\
- Over, BindParameter
+ Over, BindParameter, FunctionFilter
from .selectable import FromClause, Select, Alias
from . import operators
@@ -116,6 +116,35 @@ class FunctionElement(Executable, ColumnElement, FromClause):
"""
return Over(self, partition_by=partition_by, order_by=order_by)
+ def filter(self, *criterion):
+ """Produce a FILTER clause against this function.
+
+ Used against aggregate and window functions,
+ for database backends that support the "FILTER" clause.
+
+ The expression::
+
+ func.count(1).filter(True)
+
+ is shorthand for::
+
+ from sqlalchemy import funcfilter
+ funcfilter(func.count(1), True)
+
+ .. versionadded:: 1.0.0
+
+ .. seealso::
+
+ :class:`.FunctionFilter`
+
+ :func:`.funcfilter`
+
+ """
+ if not criterion:
+ return self
+ return FunctionFilter(self, *criterion)
+
@property
def _from_objects(self):
return self.clauses._from_objects
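Compiled with the default dialect, the shorthand above comes out along these lines; exact parameter names vary, and only backends such as PostgreSQL 9.4+ actually accept the FILTER clause::

    from sqlalchemy import column, func, select

    stmt = select([func.count(1).filter(column("x") > 5)])
    print(stmt)
    # SELECT count(:count_1) FILTER (WHERE x > :x_1) AS anon_1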
diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py
index d9fd37f92..96cabbf4f 100644
--- a/lib/sqlalchemy/sql/schema.py
+++ b/lib/sqlalchemy/sql/schema.py
@@ -412,8 +412,8 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
table.dispatch.after_parent_attach(table, metadata)
return table
except:
- metadata._remove_table(name, schema)
- raise
+ with util.safe_reraise():
+ metadata._remove_table(name, schema)
@property
@util.deprecated('0.9', 'Use ``table.schema.quote``')
@@ -728,7 +728,7 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
checkfirst=checkfirst)
def tometadata(self, metadata, schema=RETAIN_SCHEMA,
- referred_schema_fn=None):
+ referred_schema_fn=None, name=None):
"""Return a copy of this :class:`.Table` associated with a different
:class:`.MetaData`.
@@ -785,13 +785,21 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
.. versionadded:: 0.9.2
- """
+ :param name: optional string name for the target table.
+ If not specified or None, the existing table name is retained.
+ This allows a :class:`.Table` to be copied to the same
+ :class:`.MetaData` target with a new name.
+
+ .. versionadded:: 1.0.0
+ """
+ if name is None:
+ name = self.name
if schema is RETAIN_SCHEMA:
schema = self.schema
elif schema is None:
schema = metadata.schema
- key = _get_table_key(self.name, schema)
+ key = _get_table_key(name, schema)
if key in metadata.tables:
util.warn("Table '%s' already exists within the given "
"MetaData - not copying." % self.description)
@@ -801,7 +809,7 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
for c in self.columns:
args.append(c.copy(schema=schema))
table = Table(
- self.name, metadata, schema=schema,
+ name, metadata, schema=schema,
*args, **self.kwargs
)
for c in self.constraints:
@@ -1061,8 +1069,8 @@ class Column(SchemaItem, ColumnClause):
conditionally rendered differently on different backends,
consider custom compilation rules for :class:`.CreateColumn`.
- ..versionadded:: 0.8.3 Added the ``system=True`` parameter to
- :class:`.Column`.
+ .. versionadded:: 0.8.3 Added the ``system=True`` parameter to
+ :class:`.Column`.
"""
@@ -1222,8 +1230,10 @@ class Column(SchemaItem, ColumnClause):
existing = getattr(self, 'table', None)
if existing is not None and existing is not table:
raise exc.ArgumentError(
- "Column object already assigned to Table '%s'" %
- existing.description)
+ "Column object '%s' already assigned to Table '%s'" % (
+ self.key,
+ existing.description
+ ))
if self.key in table._columns:
col = table._columns.get(self.key)
@@ -1547,7 +1557,7 @@ class ForeignKey(DialectKWArgs, SchemaItem):
)
return self._schema_item_copy(fk)
- def _get_colspec(self, schema=None):
+ def _get_colspec(self, schema=None, table_name=None):
"""Return a string based 'column specification' for this
:class:`.ForeignKey`.
@@ -1557,7 +1567,15 @@ class ForeignKey(DialectKWArgs, SchemaItem):
"""
if schema:
_schema, tname, colname = self._column_tokens
+ if table_name is not None:
+ tname = table_name
return "%s.%s.%s" % (schema, tname, colname)
+ elif table_name:
+ schema, tname, colname = self._column_tokens
+ if schema:
+ return "%s.%s.%s" % (schema, table_name, colname)
+ else:
+ return "%s.%s" % (table_name, colname)
elif self._table_column is not None:
return "%s.%s" % (
self._table_column.table.fullname, self._table_column.key)
@@ -2649,10 +2667,15 @@ class ForeignKeyConstraint(Constraint):
event.listen(table.metadata, "before_drop",
ddl.DropConstraint(self, on=supports_alter))
- def copy(self, schema=None, **kw):
+ def copy(self, schema=None, target_table=None, **kw):
fkc = ForeignKeyConstraint(
[x.parent.key for x in self._elements.values()],
- [x._get_colspec(schema=schema)
+ [x._get_colspec(
+ schema=schema,
+ table_name=target_table.name
+ if target_table is not None
+ and x._table_key() == x.parent.table.key
+ else None)
for x in self._elements.values()],
name=self.name,
onupdate=self.onupdate,
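A short sketch of the new ``name`` parameter working together with the adjusted constraint copying: a table is copied into the same :class:`.MetaData` under a new name, and its self-referential foreign key is expected to follow the new name rather than point back at the original. The table here is illustrative::

    from sqlalchemy import Column, ForeignKey, Integer, MetaData, Table

    metadata = MetaData()
    node = Table(
        "node", metadata,
        Column("id", Integer, primary_key=True),
        Column("parent_id", Integer, ForeignKey("node.id")),
    )

    # copy within the same MetaData under a new name; the self-referential
    # FK is retargeted to "node_copy.id" by the colspec logic above
    node_copy = node.tometadata(metadata, name="node_copy")
    assert node_copy.name == "node_copy"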
diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py
index 9e8cb3bc5..8198a6733 100644
--- a/lib/sqlalchemy/sql/selectable.py
+++ b/lib/sqlalchemy/sql/selectable.py
@@ -746,6 +746,33 @@ class Join(FromClause):
providing a "natural join".
"""
+ constraints = cls._joincond_scan_left_right(
+ a, a_subset, b, consider_as_foreign_keys)
+
+ if len(constraints) > 1:
+ cls._joincond_trim_constraints(
+ a, b, constraints, consider_as_foreign_keys)
+
+ if len(constraints) == 0:
+ if isinstance(b, FromGrouping):
+ hint = " Perhaps you meant to convert the right side to a "\
+ "subquery using alias()?"
+ else:
+ hint = ""
+ raise exc.NoForeignKeysError(
+ "Can't find any foreign key relationships "
+ "between '%s' and '%s'.%s" %
+ (a.description, b.description, hint))
+
+ crit = [(x == y) for x, y in list(constraints.values())[0]]
+ if len(crit) == 1:
+ return (crit[0])
+ else:
+ return and_(*crit)
+
+ @classmethod
+ def _joincond_scan_left_right(
+ cls, a, a_subset, b, consider_as_foreign_keys):
constraints = collections.defaultdict(list)
for left in (a_subset, a):
@@ -780,57 +807,41 @@ class Join(FromClause):
if nrte.table_name == b.name:
raise
else:
- # this is totally covered. can't get
- # coverage to mark it.
continue
if col is not None:
constraints[fk.constraint].append((col, fk.parent))
if constraints:
break
+ return constraints
+ @classmethod
+ def _joincond_trim_constraints(
+ cls, a, b, constraints, consider_as_foreign_keys):
+ # more than one constraint matched. narrow down the list
+ # to include just those FKCs that match exactly to
+ # "consider_as_foreign_keys".
+ if consider_as_foreign_keys:
+ for const in list(constraints):
+ if set(f.parent for f in const.elements) != set(
+ consider_as_foreign_keys):
+ del constraints[const]
+
+ # if still multiple constraints, but
+ # they all refer to the exact same end result, use it.
if len(constraints) > 1:
- # more than one constraint matched. narrow down the list
- # to include just those FKCs that match exactly to
- # "consider_as_foreign_keys".
- if consider_as_foreign_keys:
- for const in list(constraints):
- if set(f.parent for f in const.elements) != set(
- consider_as_foreign_keys):
- del constraints[const]
-
- # if still multiple constraints, but
- # they all refer to the exact same end result, use it.
- if len(constraints) > 1:
- dedupe = set(tuple(crit) for crit in constraints.values())
- if len(dedupe) == 1:
- key = list(constraints)[0]
- constraints = {key: constraints[key]}
-
- if len(constraints) != 1:
- raise exc.AmbiguousForeignKeysError(
- "Can't determine join between '%s' and '%s'; "
- "tables have more than one foreign key "
- "constraint relationship between them. "
- "Please specify the 'onclause' of this "
- "join explicitly." % (a.description, b.description))
-
- if len(constraints) == 0:
- if isinstance(b, FromGrouping):
- hint = " Perhaps you meant to convert the right side to a "\
- "subquery using alias()?"
- else:
- hint = ""
- raise exc.NoForeignKeysError(
- "Can't find any foreign key relationships "
- "between '%s' and '%s'.%s" %
- (a.description, b.description, hint))
-
- crit = [(x == y) for x, y in list(constraints.values())[0]]
- if len(crit) == 1:
- return (crit[0])
- else:
- return and_(*crit)
+ dedupe = set(tuple(crit) for crit in constraints.values())
+ if len(dedupe) == 1:
+ key = list(constraints)[0]
+ constraints = {key: constraints[key]}
+
+ if len(constraints) != 1:
+ raise exc.AmbiguousForeignKeysError(
+ "Can't determine join between '%s' and '%s'; "
+ "tables have more than one foreign key "
+ "constraint relationship between them. "
+ "Please specify the 'onclause' of this "
+ "join explicitly." % (a.description, b.description))
def select(self, whereclause=None, **kwargs):
"""Create a :class:`.Select` from this :class:`.Join`.
@@ -2153,6 +2164,7 @@ class Select(HasPrefixes, GenerativeSelect):
_prefixes = ()
_hints = util.immutabledict()
+ _statement_hints = ()
_distinct = False
_from_cloned = None
_correlate = ()
@@ -2525,10 +2537,30 @@ class Select(HasPrefixes, GenerativeSelect):
return self._get_display_froms()
+ def with_statement_hint(self, text, dialect_name='*'):
+ """add a statement hint to this :class:`.Select`.
+
+ This method is similar to :meth:`.Select.with_hint` except that
+ it does not require an individual table, and instead applies to the
+ statement as a whole.
+
+ Hints here are specific to the backend database and may include
+ directives such as isolation levels, file directives, fetch directives,
+ etc.
+
+ .. versionadded:: 1.0.0
+
+ .. seealso::
+
+ :meth:`.Select.with_hint`
+
+ """
+ return self.with_hint(None, text, dialect_name)
+
@_generative
def with_hint(self, selectable, text, dialect_name='*'):
- """Add an indexing hint for the given selectable to this
- :class:`.Select`.
+ """Add an indexing or other executional context hint for the given
+ selectable to this :class:`.Select`.
The text of the hint is rendered in the appropriate
location for the database backend in use, relative
@@ -2540,7 +2572,7 @@ class Select(HasPrefixes, GenerativeSelect):
following::
select([mytable]).\\
- with_hint(mytable, "+ index(%(name)s ix_mytable)")
+ with_hint(mytable, "index(%(name)s ix_mytable)")
Would render SQL as::
@@ -2551,13 +2583,19 @@ class Select(HasPrefixes, GenerativeSelect):
and Sybase simultaneously::
select([mytable]).\\
- with_hint(
- mytable, "+ index(%(name)s ix_mytable)", 'oracle').\\
+ with_hint(mytable, "index(%(name)s ix_mytable)", 'oracle').\\
with_hint(mytable, "WITH INDEX ix_mytable", 'sybase')
+ .. seealso::
+
+ :meth:`.Select.with_statement_hint`
+
"""
- self._hints = self._hints.union(
- {(selectable, dialect_name): text})
+ if selectable is None:
+ self._statement_hints += ((dialect_name, text), )
+ else:
+ self._hints = self._hints.union(
+ {(selectable, dialect_name): text})
@property
def type(self):
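A sketch of the statement-level hint added above; the hint text is invented for illustration, and whether (and where) it renders is up to each dialect's compiler::

    from sqlalchemy import column, select, table

    mytable = table("mytable", column("x"))

    stmt = (
        select([mytable.c.x]).
        with_statement_hint("MAX_EXECUTION_TIME(1000)", dialect_name="mysql")
    )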
diff --git a/lib/sqlalchemy/testing/engines.py b/lib/sqlalchemy/testing/engines.py
index 67c13231e..1284f9c2a 100644
--- a/lib/sqlalchemy/testing/engines.py
+++ b/lib/sqlalchemy/testing/engines.py
@@ -37,8 +37,6 @@ class ConnectionKiller(object):
def _safe(self, fn):
try:
fn()
- except (SystemExit, KeyboardInterrupt):
- raise
except Exception as e:
warnings.warn(
"testing_reaper couldn't "
@@ -168,8 +166,6 @@ class ReconnectFixture(object):
def _safe(self, fn):
try:
fn()
- except (SystemExit, KeyboardInterrupt):
- raise
except Exception as e:
warnings.warn(
"ReconnectFixture couldn't "
diff --git a/lib/sqlalchemy/testing/exclusions.py b/lib/sqlalchemy/testing/exclusions.py
index 283d89e36..f94724608 100644
--- a/lib/sqlalchemy/testing/exclusions.py
+++ b/lib/sqlalchemy/testing/exclusions.py
@@ -133,7 +133,7 @@ class compound(object):
name, fail._as_string(config), str(ex))))
break
else:
- raise ex
+ util.raise_from_cause(ex)
def _expect_success(self, config, name='block'):
if not self.fails:
@@ -178,8 +178,7 @@ class Predicate(object):
@classmethod
def as_predicate(cls, predicate, description=None):
if isinstance(predicate, compound):
- return cls.as_predicate(predicate.fails.union(predicate.skips))
-
+ return cls.as_predicate(predicate.enabled_for_config, description)
elif isinstance(predicate, Predicate):
if description and predicate.description is None:
predicate.description = description
diff --git a/lib/sqlalchemy/testing/provision.py b/lib/sqlalchemy/testing/provision.py
index 0bcdad959..c8f7fdf30 100644
--- a/lib/sqlalchemy/testing/provision.py
+++ b/lib/sqlalchemy/testing/provision.py
@@ -120,7 +120,7 @@ def _pg_create_db(cfg, eng, ident):
isolation_level="AUTOCOMMIT") as conn:
try:
_pg_drop_db(cfg, conn, ident)
- except:
+ except Exception:
pass
currentdb = conn.scalar("select current_database()")
conn.execute("CREATE DATABASE %s TEMPLATE %s" % (ident, currentdb))
@@ -131,7 +131,7 @@ def _mysql_create_db(cfg, eng, ident):
with eng.connect() as conn:
try:
_mysql_drop_db(cfg, conn, ident)
- except:
+ except Exception:
pass
conn.execute("CREATE DATABASE %s" % ident)
conn.execute("CREATE DATABASE %s_test_schema" % ident)
@@ -173,15 +173,15 @@ def _mysql_drop_db(cfg, eng, ident):
with eng.connect() as conn:
try:
conn.execute("DROP DATABASE %s_test_schema" % ident)
- except:
+ except Exception:
pass
try:
conn.execute("DROP DATABASE %s_test_schema_2" % ident)
- except:
+ except Exception:
pass
try:
conn.execute("DROP DATABASE %s" % ident)
- except:
+ except Exception:
pass
diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py
index a04bcbbdd..da3e3128a 100644
--- a/lib/sqlalchemy/testing/requirements.py
+++ b/lib/sqlalchemy/testing/requirements.py
@@ -314,6 +314,20 @@ class SuiteRequirements(Requirements):
return exclusions.open()
@property
+ def temp_table_reflection(self):
+ """target dialect supports reflection of temporary tables"""
+ return exclusions.open()
+
+ @property
+ def temp_table_names(self):
+ """target dialect supports listing of temporary table names"""
+ return exclusions.closed()
+
+ @property
+ def temporary_views(self):
+ """target database supports temporary views"""
+ return exclusions.closed()
+
+ @property
def index_reflection(self):
return exclusions.open()
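Dialect test suites opt in to the new requirements by overriding these properties; a minimal sketch for a hypothetical dialect whose backend can list temp table names::

    from sqlalchemy.testing import exclusions
    from sqlalchemy.testing.requirements import SuiteRequirements

    class Requirements(SuiteRequirements):
        @property
        def temp_table_names(self):
            return exclusions.open()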
diff --git a/lib/sqlalchemy/testing/suite/test_insert.py b/lib/sqlalchemy/testing/suite/test_insert.py
index 92d3d93e5..38519dfb9 100644
--- a/lib/sqlalchemy/testing/suite/test_insert.py
+++ b/lib/sqlalchemy/testing/suite/test_insert.py
@@ -4,7 +4,7 @@ from .. import exclusions
from ..assertions import eq_
from .. import engines
-from sqlalchemy import Integer, String, select, util
+from sqlalchemy import Integer, String, select, literal_column, literal
from ..schema import Table, Column
@@ -90,6 +90,13 @@ class InsertBehaviorTest(fixtures.TablesTest):
Column('id', Integer, primary_key=True, autoincrement=False),
Column('data', String(50))
)
+ Table('includes_defaults', metadata,
+ Column('id', Integer, primary_key=True,
+ test_needs_autoincrement=True),
+ Column('data', String(50)),
+ Column('x', Integer, default=5),
+ Column('y', Integer,
+ default=literal_column("2", type_=Integer) + literal(2)))
def test_autoclose_on_insert(self):
if requirements.returning.enabled:
@@ -158,6 +165,34 @@ class InsertBehaviorTest(fixtures.TablesTest):
("data3", ), ("data3", )]
)
+ @requirements.insert_from_select
+ def test_insert_from_select_with_defaults(self):
+ table = self.tables.includes_defaults
+ config.db.execute(
+ table.insert(),
+ [
+ dict(id=1, data="data1"),
+ dict(id=2, data="data2"),
+ dict(id=3, data="data3"),
+ ]
+ )
+
+ config.db.execute(
+ table.insert(inline=True).
+ from_select(("id", "data",),
+ select([table.c.id + 5, table.c.data]).
+ where(table.c.data.in_(["data2", "data3"]))
+ ),
+ )
+
+ eq_(
+ config.db.execute(
+ select([table]).order_by(table.c.data, table.c.id)
+ ).fetchall(),
+ [(1, 'data1', 5, 4), (2, 'data2', 5, 4),
+ (7, 'data2', 5, 4), (3, 'data3', 5, 4), (8, 'data3', 5, 4)]
+ )
+
class ReturningTest(fixtures.TablesTest):
run_create_tables = 'each'
diff --git a/lib/sqlalchemy/testing/suite/test_reflection.py b/lib/sqlalchemy/testing/suite/test_reflection.py
index 575a38db9..08b858b47 100644
--- a/lib/sqlalchemy/testing/suite/test_reflection.py
+++ b/lib/sqlalchemy/testing/suite/test_reflection.py
@@ -95,6 +95,39 @@ class ComponentReflectionTest(fixtures.TablesTest):
cls.define_index(metadata, users)
if testing.requires.view_column_reflection.enabled:
cls.define_views(metadata, schema)
+ if not schema and testing.requires.temp_table_reflection.enabled:
+ cls.define_temp_tables(metadata)
+
+ @classmethod
+ def define_temp_tables(cls, metadata):
+ # cheat a bit, we should fix this with some dialect-level
+ # temp table fixture
+ if testing.against("oracle"):
+ kw = {
+ 'prefixes': ["GLOBAL TEMPORARY"],
+ 'oracle_on_commit': 'PRESERVE ROWS'
+ }
+ else:
+ kw = {
+ 'prefixes': ["TEMPORARY"],
+ }
+
+ user_tmp = Table(
+ "user_tmp", metadata,
+ Column("id", sa.INT, primary_key=True),
+ Column('name', sa.VARCHAR(50)),
+ Column('foo', sa.INT),
+ sa.UniqueConstraint('name', name='user_tmp_uq'),
+ sa.Index("user_tmp_ix", "foo"),
+ **kw
+ )
+ if testing.requires.view_reflection.enabled and \
+ testing.requires.temporary_views.enabled:
+ event.listen(
+ user_tmp, "after_create",
+ DDL("create temporary view user_tmp_v as "
+ "select * from user_tmp")
+ )
@classmethod
def define_index(cls, metadata, users):
@@ -147,6 +180,7 @@ class ComponentReflectionTest(fixtures.TablesTest):
users, addresses, dingalings = self.tables.users, \
self.tables.email_addresses, self.tables.dingalings
insp = inspect(meta.bind)
+
if table_type == 'view':
table_names = insp.get_view_names(schema)
table_names.sort()
@@ -162,6 +196,20 @@ class ComponentReflectionTest(fixtures.TablesTest):
answer = ['dingalings', 'email_addresses', 'users']
eq_(sorted(table_names), answer)
+ @testing.requires.temp_table_names
+ def test_get_temp_table_names(self):
+ insp = inspect(testing.db)
+ temp_table_names = insp.get_temp_table_names()
+ eq_(sorted(temp_table_names), ['user_tmp'])
+
+ @testing.requires.view_reflection
+ @testing.requires.temp_table_names
+ @testing.requires.temporary_views
+ def test_get_temp_view_names(self):
+ insp = inspect(self.metadata.bind)
+ temp_table_names = insp.get_temp_view_names()
+ eq_(sorted(temp_table_names), ['user_tmp_v'])
+
@testing.requires.table_reflection
def test_get_table_names(self):
self._test_get_table_names()
@@ -294,6 +342,28 @@ class ComponentReflectionTest(fixtures.TablesTest):
def test_get_columns_with_schema(self):
self._test_get_columns(schema=testing.config.test_schema)
+ @testing.requires.temp_table_reflection
+ def test_get_temp_table_columns(self):
+ meta = MetaData(testing.db)
+ user_tmp = self.tables.user_tmp
+ insp = inspect(meta.bind)
+ cols = insp.get_columns('user_tmp')
+ self.assert_(len(cols) > 0, len(cols))
+
+ for i, col in enumerate(user_tmp.columns):
+ eq_(col.name, cols[i]['name'])
+
+ @testing.requires.temp_table_reflection
+ @testing.requires.view_column_reflection
+ @testing.requires.temporary_views
+ def test_get_temp_view_columns(self):
+ insp = inspect(self.metadata.bind)
+ cols = insp.get_columns('user_tmp_v')
+ eq_(
+ [col['name'] for col in cols],
+ ['id', 'name', 'foo']
+ )
+
@testing.requires.view_column_reflection
def test_get_view_columns(self):
self._test_get_columns(table_type='view')
@@ -426,6 +496,28 @@ class ComponentReflectionTest(fixtures.TablesTest):
def test_get_unique_constraints(self):
self._test_get_unique_constraints()
+ @testing.requires.temp_table_reflection
+ @testing.requires.unique_constraint_reflection
+ def test_get_temp_table_unique_constraints(self):
+ insp = inspect(self.metadata.bind)
+ reflected = insp.get_unique_constraints('user_tmp')
+ for refl in reflected:
+ # Different dialects handle duplicate index and constraints
+ # differently, so ignore this flag
+ refl.pop('duplicates_index', None)
+ eq_(reflected, [{'column_names': ['name'], 'name': 'user_tmp_uq'}])
+
+ @testing.requires.temp_table_reflection
+ def test_get_temp_table_indexes(self):
+ insp = inspect(self.metadata.bind)
+ indexes = insp.get_indexes('user_tmp')
+ eq_(
+ # TODO: we need to add better filtering for indexes/uq constraints
+ # that are doubled up
+ [idx for idx in indexes if idx['name'] == 'user_tmp_ix'],
+ [{'unique': False, 'column_names': ['foo'], 'name': 'user_tmp_ix'}]
+ )
+
@testing.requires.unique_constraint_reflection
@testing.requires.schemas
def test_get_unique_constraints_with_schema(self):
@@ -466,6 +558,9 @@ class ComponentReflectionTest(fixtures.TablesTest):
)
for orig, refl in zip(uniques, reflected):
+ # Different dialects handle duplicate index and constraints
+ # differently, so ignore this flag
+ refl.pop('duplicates_index', None)
eq_(orig, refl)
@testing.provide_metadata
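Outside the test suite, the new inspector calls look like the following. This is a sketch that assumes a dialect implementing temp-table reflection (SQLite does as of this change) and relies on the in-memory SQLite pool reusing the same connection, since temp tables there are per-connection::

    from sqlalchemy import create_engine, inspect

    engine = create_engine("sqlite://")
    engine.execute(
        "create temporary table user_tmp (id integer, name varchar)")

    insp = inspect(engine)
    print(insp.get_temp_table_names())   # ['user_tmp']
    print(insp.get_columns("user_tmp"))  # reflection works on temp tables too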
diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py
index c963b18c3..dfed5b90a 100644
--- a/lib/sqlalchemy/util/__init__.py
+++ b/lib/sqlalchemy/util/__init__.py
@@ -33,7 +33,8 @@ from .langhelpers import iterate_attributes, class_hierarchy, \
duck_type_collection, assert_arg_type, symbol, dictlike_iteritems,\
classproperty, set_creation_order, warn_exception, warn, NoneType,\
constructor_copy, methods_equivalent, chop_traceback, asint,\
- generic_repr, counter, PluginLoader, hybridmethod, safe_reraise,\
+ generic_repr, counter, PluginLoader, hybridproperty, hybridmethod, \
+ safe_reraise,\
get_callable_argspec, only_once, attrsetter, ellipses_string, \
warn_limited
diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py
index 76f85f605..5c17bea88 100644
--- a/lib/sqlalchemy/util/langhelpers.py
+++ b/lib/sqlalchemy/util/langhelpers.py
@@ -134,7 +134,8 @@ def public_factory(target, location):
fn = target.__init__
callable_ = target
doc = "Construct a new :class:`.%s` object. \n\n"\
- "This constructor is mirrored as a public API function; see :func:`~%s` "\
+ "This constructor is mirrored as a public API function; "\
+ "see :func:`~%s` "\
"for a full usage and argument description." % (
target.__name__, location, )
else:
@@ -155,6 +156,7 @@ def %(name)s(%(args)s):
exec(code, env)
decorated = env[location_name]
decorated.__doc__ = fn.__doc__
+ decorated.__module__ = "sqlalchemy" + location.rsplit(".", 1)[0]
if compat.py2k or hasattr(fn, '__func__'):
fn.__func__.__doc__ = doc
else:
@@ -490,7 +492,7 @@ def generic_repr(obj, additional_kw=(), to_inspect=None, omit_kwarg=()):
val = getattr(obj, arg, missing)
if val is not missing and val != defval:
output.append('%s=%r' % (arg, val))
- except:
+ except Exception:
pass
if additional_kw:
@@ -499,7 +501,7 @@ def generic_repr(obj, additional_kw=(), to_inspect=None, omit_kwarg=()):
val = getattr(obj, arg, missing)
if val is not missing and val != defval:
output.append('%s=%r' % (arg, val))
- except:
+ except Exception:
pass
return "%s(%s)" % (obj.__class__.__name__, ", ".join(output))
@@ -1090,10 +1092,23 @@ class classproperty(property):
return desc.fget(cls)
+class hybridproperty(object):
+ def __init__(self, func):
+ self.func = func
+
+ def __get__(self, instance, owner):
+ if instance is None:
+ clsval = self.func(owner)
+ clsval.__doc__ = self.func.__doc__
+ return clsval
+ else:
+ return self.func(instance)
+
+
class hybridmethod(object):
"""Decorate a function as cls- or instance- level."""
- def __init__(self, func, expr=None):
+ def __init__(self, func):
self.func = func
def __get__(self, instance, owner):
@@ -1185,7 +1200,7 @@ def warn_exception(func, *args, **kwargs):
"""
try:
return func(*args, **kwargs)
- except:
+ except Exception:
warn("%s('%s') ignored" % sys.exc_info()[0:2])