Diffstat (limited to 'lib/sqlalchemy')
48 files changed, 2479 insertions, 2651 deletions
diff --git a/lib/sqlalchemy/__init__.py b/lib/sqlalchemy/__init__.py index 4e00437ea..d8534dd49 100644 --- a/lib/sqlalchemy/__init__.py +++ b/lib/sqlalchemy/__init__.py @@ -4,11 +4,9 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -import inspect +import inspect as _inspect import sys -import sqlalchemy.exc as exceptions - from sqlalchemy.sql import ( alias, and_, @@ -111,15 +109,17 @@ from sqlalchemy.schema import ( UniqueConstraint, ) +from sqlalchemy.inspection import inspect + from sqlalchemy.engine import create_engine, engine_from_config __all__ = sorted(name for name, obj in locals().items() - if not (name.startswith('_') or inspect.ismodule(obj))) + if not (name.startswith('_') or _inspect.ismodule(obj))) -__version__ = '0.7.7' +__version__ = '0.8.0b1' -del inspect, sys +del _inspect, sys from sqlalchemy import util as _sa_util _sa_util.importlater.resolve_all() diff --git a/lib/sqlalchemy/dialects/__init__.py b/lib/sqlalchemy/dialects/__init__.py index 2d4832412..16eb32e21 100644 --- a/lib/sqlalchemy/dialects/__init__.py +++ b/lib/sqlalchemy/dialects/__init__.py @@ -17,3 +17,31 @@ __all__ = ( 'sqlite', 'sybase', ) + +from sqlalchemy import util + +def _auto_fn(name): + """default dialect importer. + + plugs into the :class:`.PluginLoader` + as a first-hit system. + + """ + if "." in name: + dialect, driver = name.split(".") + else: + dialect = name + driver = "base" + try: + module = __import__('sqlalchemy.dialects.%s' % (dialect, )).dialects + except ImportError: + return None + + module = getattr(module, dialect) + if hasattr(module, driver): + module = getattr(module, driver) + return lambda: module.dialect + else: + return None + +registry = util.PluginLoader("sqlalchemy.dialects", auto_fn=_auto_fn)
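With this commit, dialect lookup is consolidated behind ``registry``: a name is resolved lazily by ``_auto_fn`` the first time it is requested, and ``URL.get_dialect()`` (see the ``engine/url.py`` hunk further down) converts a drivername such as ``mysql+mysqldb`` to ``mysql.mysqldb`` before calling ``registry.load()``. A minimal sketch of the resulting behavior (the specific dialect names here are illustrative only)::

    from sqlalchemy.dialects import registry

    # a bare name is treated by _auto_fn as "<name>.base", which resolves
    # to the default driver configured in that dialect package
    sqlite_dialect = registry.load("sqlite")

    # a "dialect.driver" name resolves to that driver's dialect class
    mysqldb_dialect = registry.load("mysql.mysqldb")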
\ No newline at end of file diff --git a/lib/sqlalchemy/dialects/drizzle/__init__.py b/lib/sqlalchemy/dialects/drizzle/__init__.py index bbd716f59..1392b8e28 100644 --- a/lib/sqlalchemy/dialects/drizzle/__init__.py +++ b/lib/sqlalchemy/dialects/drizzle/__init__.py @@ -1,18 +1,22 @@ from sqlalchemy.dialects.drizzle import base, mysqldb -# default dialect base.dialect = mysqldb.dialect from sqlalchemy.dialects.drizzle.base import \ - BIGINT, BINARY, BLOB, BOOLEAN, CHAR, DATE, DATETIME, \ - DECIMAL, DOUBLE, ENUM, \ - FLOAT, INTEGER, \ - NUMERIC, REAL, TEXT, TIME, TIMESTAMP, \ - VARBINARY, VARCHAR, dialect - + BIGINT, BINARY, BLOB, \ + BOOLEAN, CHAR, DATE, \ + DATETIME, DECIMAL, DOUBLE, \ + ENUM, FLOAT, INTEGER, \ + NUMERIC, REAL, TEXT, \ + TIME, TIMESTAMP, VARBINARY, \ + VARCHAR, dialect + __all__ = ( -'BIGINT', 'BINARY', 'BLOB', 'BOOLEAN', 'CHAR', 'DATE', 'DATETIME', 'DECIMAL', 'DOUBLE', -'ENUM', 'FLOAT', 'INTEGER', -'NUMERIC', 'SET', 'REAL', 'TEXT', 'TIME', 'TIMESTAMP', -'VARBINARY', 'VARCHAR', 'dialect' + 'BIGINT', 'BINARY', 'BLOB', + 'BOOLEAN', 'CHAR', 'DATE', + 'DATETIME', 'DECIMAL', 'DOUBLE', + 'ENUM', 'FLOAT', 'INTEGER', + 'NUMERIC', 'REAL', 'TEXT', + 'TIME', 'TIMESTAMP', 'VARBINARY', + 'VARCHAR', 'dialect' ) diff --git a/lib/sqlalchemy/dialects/drizzle/base.py b/lib/sqlalchemy/dialects/drizzle/base.py index 62967174f..94d2711a0 100644 --- a/lib/sqlalchemy/dialects/drizzle/base.py +++ b/lib/sqlalchemy/dialects/drizzle/base.py @@ -5,134 +5,32 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -"""Support for the Drizzle database. - -Supported Versions and Features -------------------------------- - -SQLAlchemy supports the Drizzle database starting with 2010.08. -with capabilities increasing with more modern servers. -Most available DBAPI drivers are supported; see below. +"""Support for the Drizzle database. -===================================== =============== -Feature Minimum Version -===================================== =============== -sqlalchemy.orm 2010.08 -Table Reflection 2010.08 -DDL Generation 2010.08 -utf8/Full Unicode Connections 2010.08 -Transactions 2010.08 -Two-Phase Transactions 2010.08 -Nested Transactions 2010.08 -===================================== =============== +Drizzle is a variant of MySQL. Unlike MySQL, Drizzle's default storage engine +is InnoDB (transactions, foreign-keys) rather than MyISAM. For more +`Notable Differences <http://docs.drizzle.org/mysql_differences.html>`_, visit +the `Drizzle Documentation <http://docs.drizzle.org/index.html>`_. -See the official Drizzle documentation for detailed information about features -supported in any given server release. +The SQLAlchemy Drizzle dialect leans heavily on the MySQL dialect, so much of +the :doc:`SQLAlchemy MySQL <mysql>` documentation is also relevant. Connecting ---------- -See the API documentation on individual drivers for details on connecting. - -Connection Timeouts -------------------- - -Drizzle features an automatic connection close behavior, for connections that -have been idle for eight hours or more. To circumvent having this issue, use -the ``pool_recycle`` option which controls the maximum age of any connection:: - - engine = create_engine('drizzle+mysqldb://...', pool_recycle=3600) - -Storage Engines ---------------- - -Drizzle defaults to the ``InnoDB`` storage engine, which is transactional. 
- -Storage engines can be elected when creating tables in SQLAlchemy by supplying -a ``drizzle_engine='whatever'`` to the ``Table`` constructor. Any Drizzle table -creation option can be specified in this syntax:: - - Table('mytable', metadata, - Column('data', String(32)), - drizzle_engine='InnoDB', - ) - -Keys ----- - -Not all Drizzle storage engines support foreign keys. For ``BlitzDB`` and -similar engines, the information loaded by table reflection will not include -foreign keys. For these tables, you may supply a -:class:`~sqlalchemy.ForeignKeyConstraint` at reflection time:: - - Table('mytable', metadata, - ForeignKeyConstraint(['other_id'], ['othertable.other_id']), - autoload=True - ) - -When creating tables, SQLAlchemy will automatically set ``AUTO_INCREMENT`` on -an integer primary key column:: - - >>> t = Table('mytable', metadata, - ... Column('mytable_id', Integer, primary_key=True) - ... ) - >>> t.create() - CREATE TABLE mytable ( - id INTEGER NOT NULL AUTO_INCREMENT, - PRIMARY KEY (id) - ) - -You can disable this behavior by supplying ``autoincrement=False`` to the -:class:`~sqlalchemy.Column`. This flag can also be used to enable -auto-increment on a secondary column in a multi-column key for some storage -engines:: - - Table('mytable', metadata, - Column('gid', Integer, primary_key=True, autoincrement=False), - Column('id', Integer, primary_key=True) - ) - -Drizzle SQL Extensions ----------------------- - -Many of the Drizzle SQL extensions are handled through SQLAlchemy's generic -function and operator support:: - - table.select(table.c.password==func.md5('plaintext')) - table.select(table.c.username.op('regexp')('^[a-d]')) - -And of course any valid Drizzle statement can be executed as a string as well. - -Some limited direct support for Drizzle extensions to SQL is currently -available. - -* SELECT pragma:: - - select(..., prefixes=['HIGH_PRIORITY', 'SQL_SMALL_RESULT']) - -* UPDATE with LIMIT:: - - update(..., drizzle_limit=10) +See the individual driver sections below for details on connecting. 
""" -import datetime, inspect, re, sys - -from sqlalchemy import schema as sa_schema -from sqlalchemy import exc, log, sql, util -from sqlalchemy.sql import operators as sql_operators -from sqlalchemy.sql import functions as sql_functions -from sqlalchemy.sql import compiler -from array import array as _array - -from sqlalchemy.engine import reflection -from sqlalchemy.engine import base as engine_base, default +from sqlalchemy import exc +from sqlalchemy import log from sqlalchemy import types as sqltypes +from sqlalchemy.engine import reflection from sqlalchemy.dialects.mysql import base as mysql_dialect - from sqlalchemy.types import DATE, DATETIME, BOOLEAN, TIME, \ - BLOB, BINARY, VARBINARY + BLOB, BINARY, VARBINARY + class _NumericType(object): """Base for Drizzle numeric types.""" @@ -140,6 +38,7 @@ class _NumericType(object): def __init__(self, **kw): super(_NumericType, self).__init__(**kw) + class _FloatType(_NumericType, sqltypes.Float): def __init__(self, precision=None, scale=None, asdecimal=True, **kw): if isinstance(self, (REAL, DOUBLE)) and \ @@ -147,23 +46,22 @@ class _FloatType(_NumericType, sqltypes.Float): (precision is None and scale is not None) or (precision is not None and scale is None) ): - raise exc.ArgumentError( - "You must specify both precision and scale or omit " - "both altogether.") + raise exc.ArgumentError( + "You must specify both precision and scale or omit " + "both altogether.") - super(_FloatType, self).__init__(precision=precision, asdecimal=asdecimal, **kw) + super(_FloatType, self).__init__(precision=precision, + asdecimal=asdecimal, **kw) self.scale = scale + class _StringType(mysql_dialect._StringType): """Base for Drizzle string types.""" - def __init__(self, collation=None, - binary=False, - **kw): + def __init__(self, collation=None, binary=False, **kw): kw['national'] = False - super(_StringType, self).__init__(collation=collation, - binary=binary, - **kw) + super(_StringType, self).__init__(collation=collation, binary=binary, + **kw) class NUMERIC(_NumericType, sqltypes.NUMERIC): @@ -180,7 +78,9 @@ class NUMERIC(_NumericType, sqltypes.NUMERIC): :param scale: The number of digits after the decimal point. """ - super(NUMERIC, self).__init__(precision=precision, scale=scale, asdecimal=asdecimal, **kw) + + super(NUMERIC, self).__init__(precision=precision, scale=scale, + asdecimal=asdecimal, **kw) class DECIMAL(_NumericType, sqltypes.DECIMAL): @@ -215,9 +115,11 @@ class DOUBLE(_FloatType): :param scale: The number of digits after the decimal point. """ + super(DOUBLE, self).__init__(precision=precision, scale=scale, asdecimal=asdecimal, **kw) + class REAL(_FloatType, sqltypes.REAL): """Drizzle REAL type.""" @@ -232,9 +134,11 @@ class REAL(_FloatType, sqltypes.REAL): :param scale: The number of digits after the decimal point. """ + super(REAL, self).__init__(precision=precision, scale=scale, asdecimal=asdecimal, **kw) + class FLOAT(_FloatType, sqltypes.FLOAT): """Drizzle FLOAT type.""" @@ -249,42 +153,46 @@ class FLOAT(_FloatType, sqltypes.FLOAT): :param scale: The number of digits after the decimal point. """ + super(FLOAT, self).__init__(precision=precision, scale=scale, asdecimal=asdecimal, **kw) def bind_processor(self, dialect): return None + class INTEGER(sqltypes.INTEGER): """Drizzle INTEGER type.""" __visit_name__ = 'INTEGER' def __init__(self, **kw): - """Construct an INTEGER. 
+ """Construct an INTEGER.""" - """ super(INTEGER, self).__init__(**kw) + class BIGINT(sqltypes.BIGINT): """Drizzle BIGINTEGER type.""" __visit_name__ = 'BIGINT' def __init__(self, **kw): - """Construct a BIGINTEGER. + """Construct a BIGINTEGER.""" - """ super(BIGINT, self).__init__(**kw) class _DrizzleTime(mysql_dialect._MSTime): """Drizzle TIME type.""" + class TIMESTAMP(sqltypes.TIMESTAMP): """Drizzle TIMESTAMP type.""" + __visit_name__ = 'TIMESTAMP' + class TEXT(_StringType, sqltypes.TEXT): """Drizzle TEXT type, for text up to 2^16 characters.""" @@ -306,8 +214,10 @@ class TEXT(_StringType, sqltypes.TEXT): only the collation of character data. """ + super(TEXT, self).__init__(length=length, **kw) + class VARCHAR(_StringType, sqltypes.VARCHAR): """Drizzle VARCHAR type, for variable-length character data.""" @@ -325,8 +235,10 @@ class VARCHAR(_StringType, sqltypes.VARCHAR): only the collation of character data. """ + super(VARCHAR, self).__init__(length=length, **kwargs) + class CHAR(_StringType, sqltypes.CHAR): """Drizzle CHAR type, for fixed-length character data.""" @@ -345,8 +257,10 @@ class CHAR(_StringType, sqltypes.CHAR): compatible with the national character set. """ + super(CHAR, self).__init__(length=length, **kwargs) + class ENUM(mysql_dialect.ENUM): """Drizzle ENUM type.""" @@ -363,8 +277,9 @@ class ENUM(mysql_dialect.ENUM): :param strict: Defaults to False: ensure that a given value is in this ENUM's range of permissible values when inserting or updating rows. - Note that Drizzle will not raise a fatal error if you attempt to store - an out of range value- an alternate value will be stored instead. + Note that Drizzle will not raise a fatal error if you attempt to + store an out of range value- an alternate value will be stored + instead. (See Drizzle ENUM documentation.) :param collation: Optional, a column-level collation for this string @@ -390,12 +305,15 @@ class ENUM(mysql_dialect.ENUM): literals for you. This is a transitional option. """ + super(ENUM, self).__init__(*enums, **kw) + class _DrizzleBoolean(sqltypes.Boolean): def get_dbapi_type(self, dbapi): return dbapi.NUMERIC + colspecs = { sqltypes.Numeric: NUMERIC, sqltypes.Float: FLOAT, @@ -404,6 +322,7 @@ colspecs = { sqltypes.Boolean: _DrizzleBoolean, } + # All the types we have in Drizzle ischema_names = { 'BIGINT': BIGINT, @@ -427,6 +346,7 @@ ischema_names = { 'VARCHAR': VARCHAR, } + class DrizzleCompiler(mysql_dialect.MySQLCompiler): def visit_typeclause(self, typeclause): @@ -439,7 +359,7 @@ class DrizzleCompiler(mysql_dialect.MySQLCompiler): def visit_cast(self, cast, **kwargs): type_ = self.process(cast.typeclause) if type_ is None: - return self.process(cast.clause) + return self.process(cast.clause) return 'CAST(%s AS %s)' % (self.process(cast.clause), type_) @@ -447,12 +367,13 @@ class DrizzleCompiler(mysql_dialect.MySQLCompiler): class DrizzleDDLCompiler(mysql_dialect.MySQLDDLCompiler): pass + class DrizzleTypeCompiler(mysql_dialect.MySQLTypeCompiler): def _extend_numeric(self, type_, spec): return spec def _extend_string(self, type_, defaults, spec): - """Extend a string-type declaration with standard SQL + """Extend a string-type declaration with standard SQL COLLATE annotations and Drizzle specific extensions. 
""" @@ -492,11 +413,16 @@ class DrizzleTypeCompiler(mysql_dialect.MySQLTypeCompiler): class DrizzleExecutionContext(mysql_dialect.MySQLExecutionContext): pass + class DrizzleIdentifierPreparer(mysql_dialect.MySQLIdentifierPreparer): pass + class DrizzleDialect(mysql_dialect.MySQLDialect): - """Details of the Drizzle dialect. Not used directly in application code.""" + """Details of the Drizzle dialect. + + Not used directly in application code. + """ name = 'drizzle' @@ -505,7 +431,6 @@ class DrizzleDialect(mysql_dialect.MySQLDialect): supports_native_boolean = True supports_views = False - default_paramstyle = 'format' colspecs = colspecs @@ -516,8 +441,8 @@ class DrizzleDialect(mysql_dialect.MySQLDialect): preparer = DrizzleIdentifierPreparer def on_connect(self): - """Force autocommit - Drizzle Bug#707842 doesn't set this - properly""" + """Force autocommit - Drizzle Bug#707842 doesn't set this properly""" + def connect(conn): conn.autocommit(False) return connect @@ -535,6 +460,7 @@ class DrizzleDialect(mysql_dialect.MySQLDialect): @reflection.cache def get_table_names(self, connection, schema=None, **kw): """Return a Unicode SHOW TABLES from a given schema.""" + if schema is not None: current_schema = schema else: @@ -554,8 +480,8 @@ class DrizzleDialect(mysql_dialect.MySQLDialect): Cached per-connection. This value can not change without a server restart. - """ + return 0 def _detect_collations(self, connection): @@ -566,7 +492,9 @@ class DrizzleDialect(mysql_dialect.MySQLDialect): collations = {} charset = self._connection_charset - rs = connection.execute('SELECT CHARACTER_SET_NAME, COLLATION_NAME from data_dictionary.COLLATIONS') + rs = connection.execute( + 'SELECT CHARACTER_SET_NAME, COLLATION_NAME FROM' + ' data_dictionary.COLLATIONS') for row in self._compat_fetchall(rs, charset): collations[row[0]] = row[1] return collations @@ -575,8 +503,7 @@ class DrizzleDialect(mysql_dialect.MySQLDialect): """Detect and adjust for the ANSI_QUOTES sql mode.""" self._server_ansiquotes = False - self._backslash_escapes = False -log.class_logger(DrizzleDialect) +log.class_logger(DrizzleDialect) diff --git a/lib/sqlalchemy/dialects/drizzle/mysqldb.py b/lib/sqlalchemy/dialects/drizzle/mysqldb.py index 01116fa93..ce9518a81 100644 --- a/lib/sqlalchemy/dialects/drizzle/mysqldb.py +++ b/lib/sqlalchemy/dialects/drizzle/mysqldb.py @@ -1,11 +1,9 @@ -"""Support for the Drizzle database via the Drizzle-python adapter. +"""Support for the Drizzle database via the mysql-python adapter. -Drizzle-Python is available at: +MySQL-Python is available at: http://sourceforge.net/projects/mysql-python -At least version 1.2.1 or 1.2.2 should be used. - Connecting ----------- @@ -13,37 +11,22 @@ Connect string format:: drizzle+mysqldb://<user>:<password>@<host>[:<port>]/<dbname> -Unicode -------- - -Drizzle accommodates Python ``unicode`` objects directly and -uses the ``utf8`` encoding in all cases. - -Known Issues -------------- - -Drizzle-python at least as of version 1.2.2 has a serious memory leak related -to unicode conversion, a feature which is disabled via ``use_unicode=0``. 
-The recommended connection form with SQLAlchemy is:: - - engine = create_engine('mysql://scott:tiger@localhost/test?charset=utf8&use_unicode=0', pool_recycle=3600) - - """ -from sqlalchemy.dialects.drizzle.base import (DrizzleDialect, - DrizzleExecutionContext, - DrizzleCompiler, DrizzleIdentifierPreparer) +from sqlalchemy.dialects.drizzle.base import ( + DrizzleDialect, + DrizzleExecutionContext, + DrizzleCompiler, + DrizzleIdentifierPreparer) from sqlalchemy.connectors.mysqldb import ( - MySQLDBExecutionContext, - MySQLDBCompiler, - MySQLDBIdentifierPreparer, - MySQLDBConnector - ) - -class DrizzleExecutionContext_mysqldb( - MySQLDBExecutionContext, - DrizzleExecutionContext): + MySQLDBExecutionContext, + MySQLDBCompiler, + MySQLDBIdentifierPreparer, + MySQLDBConnector) + + +class DrizzleExecutionContext_mysqldb(MySQLDBExecutionContext, + DrizzleExecutionContext): pass @@ -51,11 +34,11 @@ class DrizzleCompiler_mysqldb(MySQLDBCompiler, DrizzleCompiler): pass -class DrizzleIdentifierPreparer_mysqldb( - MySQLDBIdentifierPreparer, - DrizzleIdentifierPreparer): +class DrizzleIdentifierPreparer_mysqldb(MySQLDBIdentifierPreparer, + DrizzleIdentifierPreparer): pass + class DrizzleDialect_mysqldb(MySQLDBConnector, DrizzleDialect): execution_ctx_cls = DrizzleExecutionContext_mysqldb statement_compiler = DrizzleCompiler_mysqldb @@ -63,6 +46,7 @@ class DrizzleDialect_mysqldb(MySQLDBConnector, DrizzleDialect): def _detect_charset(self, connection): """Sniff out the character set in use for connection results.""" + return 'utf8' diff --git a/lib/sqlalchemy/dialects/informix/base.py b/lib/sqlalchemy/dialects/informix/base.py index 57bb9c379..044fb525d 100644 --- a/lib/sqlalchemy/dialects/informix/base.py +++ b/lib/sqlalchemy/dialects/informix/base.py @@ -23,7 +23,7 @@ from sqlalchemy import types as sqltypes RESERVED_WORDS = set( ["abs", "absolute", "access", "access_method", "acos", "active", "add", "address", "add_months", "admin", "after", "aggregate", "alignment", - "all", "allocate", "all_rows", "altere", "and", "ansi", "any", "append", + "all", "allocate", "all_rows", "alter", "and", "ansi", "any", "append", "array", "as", "asc", "ascii", "asin", "at", "atan", "atan2", "attach", "attributes", "audit", "authentication", "authid", "authorization", "authorized", "auto", "autofree", "auto_reprepare", "auto_stat_mode", diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index d63286bba..4d0af7cbe 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -146,7 +146,7 @@ Enabling Snapshot Isolation Not necessarily specific to SQLAlchemy, SQL Server has a default transaction isolation mode that locks entire tables, and causes even mildly concurrent -applications to have long held locks and frequent deadlocks. +applications to have long held locks and frequent deadlocks. Enabling snapshot isolation for the database as a whole is recommended for modern levels of concurrency support. This is accomplished via the following ALTER DATABASE commands executed at the SQL prompt:: @@ -158,29 +158,6 @@ following ALTER DATABASE commands executed at the SQL prompt:: Background on SQL Server snapshot isolation is available at http://msdn.microsoft.com/en-us/library/ms175095.aspx. -Scalar Select Comparisons -------------------------- - -The MSSQL dialect contains a legacy behavior whereby comparing -a scalar select to a value using the ``=`` or ``!=`` operator -will resolve to IN or NOT IN, respectively. 
This behavior is -deprecated and will be removed in 0.8 - the ``s.in_()``/``~s.in_()`` operators -should be used when IN/NOT IN are desired. - -For the time being, the existing behavior prevents a comparison -between scalar select and another value that actually wants to use ``=``. -To remove this behavior in a forwards-compatible way, apply this -compilation rule by placing the following code at the module import -level:: - - from sqlalchemy.ext.compiler import compiles - from sqlalchemy.sql.expression import _BinaryExpression - from sqlalchemy.sql.compiler import SQLCompiler - - @compiles(_BinaryExpression, 'mssql') - def override_legacy_binary(element, compiler, **kw): - return SQLCompiler.visit_binary(compiler, element, **kw) - Known Issues ------------ @@ -689,18 +666,22 @@ class MSExecutionContext(default.DefaultExecutionContext): not self.executemany if self._enable_identity_insert: - self.cursor.execute("SET IDENTITY_INSERT %s ON" % - self.dialect.identifier_preparer.format_table(tbl)) + self.root_connection._cursor_execute(self.cursor, + "SET IDENTITY_INSERT %s ON" % + self.dialect.identifier_preparer.format_table(tbl), + ()) def post_exec(self): """Disable IDENTITY_INSERT if enabled.""" + conn = self.root_connection if self._select_lastrowid: if self.dialect.use_scope_identity: - self.cursor.execute( - "SELECT scope_identity() AS lastrowid", ()) + conn._cursor_execute(self.cursor, + "SELECT scope_identity() AS lastrowid", ()) else: - self.cursor.execute("SELECT @@identity AS lastrowid", ()) + conn._cursor_execute(self.cursor, + "SELECT @@identity AS lastrowid", ()) # fetchall() ensures the cursor is consumed without closing it row = self.cursor.fetchall()[0] self._lastrowid = int(row[0]) @@ -710,10 +691,11 @@ class MSExecutionContext(default.DefaultExecutionContext): self._result_proxy = base.FullyBufferedResultProxy(self) if self._enable_identity_insert: - self.cursor.execute( + conn._cursor_execute(self.cursor, "SET IDENTITY_INSERT %s OFF" % self.dialect.identifier_preparer. - format_table(self.compiled.statement.table) + format_table(self.compiled.statement.table), + () ) def get_lastrowid(self): @@ -874,7 +856,9 @@ class MSSQLCompiler(compiler.SQLCompiler): t, column) if result_map is not None: - result_map[column.name.lower()] = \ + result_map[column.name + if self.dialect.case_sensitive + else column.name.lower()] = \ (column.name, (column, ), column.type) @@ -901,31 +885,7 @@ class MSSQLCompiler(compiler.SQLCompiler): binary.left, binary.operator), **kwargs) - else: - if ( - (binary.operator is operator.eq or - binary.operator is operator.ne) - and ( - (isinstance(binary.left, expression._FromGrouping) - and isinstance(binary.left.element, - expression._ScalarSelect)) - or (isinstance(binary.right, expression._FromGrouping) - and isinstance(binary.right.element, - expression._ScalarSelect)) - or isinstance(binary.left, expression._ScalarSelect) - or isinstance(binary.right, expression._ScalarSelect) - ) - ): - op = binary.operator == operator.eq and "IN" or "NOT IN" - util.warn_deprecated("Comparing a scalar select using ``=``/``!=`` will " - "no longer produce IN/NOT IN in 0.8. 
To remove this " - "behavior immediately, use the recipe at " - "http://www.sqlalchemy.org/docs/07/dialects/mssql.html#scalar-select-comparisons") - return self.process( - expression._BinaryExpression(binary.left, - binary.right, op), - **kwargs) - return super(MSSQLCompiler, self).visit_binary(binary, **kwargs) + return super(MSSQLCompiler, self).visit_binary(binary, **kwargs) def returning_clause(self, stmt, returning_cols): @@ -980,6 +940,22 @@ class MSSQLCompiler(compiler.SQLCompiler): else: return "" + def update_from_clause(self, update_stmt, + from_table, extra_froms, + from_hints, + **kw): + """Render the UPDATE..FROM clause specific to MSSQL. + + In MSSQL, if the UPDATE statement involves an alias of the table to + be updated, then the table itself must be added to the FROM list as + well. Otherwise, it is optional. Here, we add it regardless. + + """ + return "FROM " + ', '.join( + t._compiler_dispatch(self, asfrom=True, + fromhints=from_hints, **kw) + for t in [from_table] + extra_froms) + class MSSQLStrictCompiler(MSSQLCompiler): """A subclass of MSSQLCompiler which disables the usage of bind parameters where not allowed natively by MS-SQL. @@ -1326,6 +1302,7 @@ class MSDialect(default.DefaultDialect): whereclause = columns.c.table_name==tablename s = sql.select([columns], whereclause, order_by=[columns.c.ordinal_position]) + c = connection.execute(s) cols = [] while True: diff --git a/lib/sqlalchemy/dialects/mssql/pyodbc.py b/lib/sqlalchemy/dialects/mssql/pyodbc.py index 434cfd43c..a7cb42aac 100644 --- a/lib/sqlalchemy/dialects/mssql/pyodbc.py +++ b/lib/sqlalchemy/dialects/mssql/pyodbc.py @@ -242,7 +242,8 @@ class MSDialect_pyodbc(PyODBCConnector, MSDialect): def __init__(self, description_encoding='latin-1', **params): super(MSDialect_pyodbc, self).__init__(**params) self.description_encoding = description_encoding - self.use_scope_identity = self.dbapi and \ + self.use_scope_identity = self.use_scope_identity and \ + self.dbapi and \ hasattr(self.dbapi.Cursor, 'nextset') self._need_decimal_fix = self.dbapi and \ self._dbapi_version() < (2, 1, 8) diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index 10f56f751..a61d59e9b 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -344,9 +344,9 @@ class _FloatType(_NumericType, sqltypes.Float): (precision is None and scale is not None) or (precision is not None and scale is None) ): - raise exc.ArgumentError( - "You must specify both precision and scale or omit " - "both altogether.") + raise exc.ArgumentError( + "You must specify both precision and scale or omit " + "both altogether.") super(_FloatType, self).__init__(precision=precision, asdecimal=asdecimal, **kw) self.scale = scale @@ -1273,11 +1273,11 @@ class MySQLCompiler(compiler.SQLCompiler): def visit_cast(self, cast, **kwargs): # No cast until 4, no decimals until 5. 
if not self.dialect._supports_cast: - return self.process(cast.clause) + return self.process(cast.clause.self_group()) type_ = self.process(cast.typeclause) if type_ is None: - return self.process(cast.clause) + return self.process(cast.clause.self_group()) return 'CAST(%s AS %s)' % (self.process(cast.clause), type_) @@ -1395,8 +1395,12 @@ class MySQLDDLCompiler(compiler.DDLCompiler): auto_inc_column is not list(table.primary_key)[0]: if constraint_string: constraint_string += ", \n\t" - constraint_string += "KEY `idx_autoinc_%s`(`%s`)" % (auto_inc_column.name, \ - self.preparer.format_column(auto_inc_column)) + constraint_string += "KEY %s (%s)" % ( + self.preparer.quote( + "idx_autoinc_%s" % auto_inc_column.name, None + ), + self.preparer.format_column(auto_inc_column) + ) return constraint_string @@ -2017,7 +2021,6 @@ class MySQLDialect(default.DefaultDialect): @reflection.cache def get_view_names(self, connection, schema=None, **kw): - charset = self._connection_charset if self.server_version_info < (5, 0, 2): raise NotImplementedError if schema is None: @@ -2028,7 +2031,7 @@ class MySQLDialect(default.DefaultDialect): rp = connection.execute("SHOW FULL TABLES FROM %s" % self.identifier_preparer.quote_identifier(schema)) return [row[0] for row in self._compat_fetchall(rp, charset=charset)\ - if row[1] == 'VIEW'] + if row[1] in ('VIEW', 'SYSTEM VIEW')] @reflection.cache def get_table_options(self, connection, table_name, schema=None, **kw): diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index c6eff8411..d47b9e757 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -317,7 +317,7 @@ class UUID(sqltypes.TypeEngine): PGUuid = UUID -class ARRAY(sqltypes.MutableType, sqltypes.Concatenable, sqltypes.TypeEngine): +class ARRAY(sqltypes.Concatenable, sqltypes.TypeEngine): """Postgresql ARRAY type. Represents values as Python lists. @@ -329,7 +329,7 @@ class ARRAY(sqltypes.MutableType, sqltypes.Concatenable, sqltypes.TypeEngine): """ __visit_name__ = 'ARRAY' - def __init__(self, item_type, mutable=False, as_tuple=False): + def __init__(self, item_type, as_tuple=False): """Construct an ARRAY. E.g.:: @@ -344,25 +344,10 @@ class ARRAY(sqltypes.MutableType, sqltypes.Concatenable, sqltypes.TypeEngine): ``ARRAY(ARRAY(Integer))`` or such. The type mapping figures out on the fly - :param mutable=False: Specify whether lists passed to this - class should be considered mutable - this enables - "mutable types" mode in the ORM. Be sure to read the - notes for :class:`.MutableType` regarding ORM - performance implications (default changed from ``True`` in - 0.7.0). - - .. note:: - - This functionality is now superseded by the - ``sqlalchemy.ext.mutable`` extension described in - :ref:`mutable_toplevel`. - :param as_tuple=False: Specify whether return results should be converted to tuples from lists. DBAPIs such as psycopg2 return lists by default. When tuples are - returned, the results are hashable. This flag can only - be set to ``True`` when ``mutable`` is set to - ``False``. (new in 0.6.5) + returned, the results are hashable. """ if isinstance(item_type, ARRAY): @@ -371,27 +356,11 @@ class ARRAY(sqltypes.MutableType, sqltypes.Concatenable, sqltypes.TypeEngine): if isinstance(item_type, type): item_type = item_type() self.item_type = item_type - self.mutable = mutable - if mutable and as_tuple: - raise exc.ArgumentError( - "mutable must be set to False if as_tuple is True." 
- ) self.as_tuple = as_tuple - def copy_value(self, value): - if value is None: - return None - elif self.mutable: - return list(value) - else: - return value - def compare_values(self, x, y): return x == y - def is_mutable(self): - return self.mutable - def bind_processor(self, dialect): item_proc = self.item_type.dialect_impl(dialect).bind_processor(dialect) if item_proc: diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index 4c7d0173e..754bf7966 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -83,10 +83,7 @@ class DATETIME(_DateTimeMixin, sqltypes.DateTime): The default string storage format is:: - "%04d-%02d-%02d %02d:%02d:%02d.%06d" % (value.year, - value.month, value.day, - value.hour, value.minute, - value.second, value.microsecond) + "%(year)04d-%(month)02d-%(day)02d %(hour)02d:%(minute)02d:%(second)02d.%(microsecond)06d" e.g.:: @@ -99,22 +96,38 @@ from sqlalchemy.dialects.sqlite import DATETIME dt = DATETIME( - storage_format="%04d/%02d/%02d %02d-%02d-%02d-%06d", - regexp=re.compile("(\d+)/(\d+)/(\d+) (\d+)-(\d+)-(\d+)(?:-(\d+))?") + storage_format="%(year)04d/%(month)02d/%(day)02d %(hour)02d:%(minute)02d:%(second)02d", + regexp=re.compile("(\d+)/(\d+)/(\d+) (\d+):(\d+):(\d+)") ) :param storage_format: format string which will be applied to the - tuple ``(value.year, value.month, value.day, value.hour, - value.minute, value.second, value.microsecond)``, given a - Python datetime.datetime() object. + dict with keys year, month, day, hour, minute, second, and microsecond. :param regexp: regular expression which will be applied to - incoming result rows. The resulting match object is applied to - the Python datetime() constructor via ``*map(int, - match_obj.groups(0))``. + incoming result rows. If the regexp contains named groups, the + resulting match dict is applied to the Python datetime() constructor + as keyword arguments. Otherwise, if positional groups are used, + the datetime() constructor is called with positional arguments via + ``*map(int, match_obj.groups(0))``. """ - _storage_format = "%04d-%02d-%02d %02d:%02d:%02d.%06d" + _storage_format = ( + "%(year)04d-%(month)02d-%(day)02d " + "%(hour)02d:%(minute)02d:%(second)02d.%(microsecond)06d" + ) + + def __init__(self, *args, **kwargs): + truncate_microseconds = kwargs.pop('truncate_microseconds', False) + super(DATETIME, self).__init__(*args, **kwargs) + if truncate_microseconds: + assert 'storage_format' not in kwargs, "You can specify only "\ + "one of truncate_microseconds or storage_format." + assert 'regexp' not in kwargs, "You can specify only one of "\ + "truncate_microseconds or regexp."
+ self._storage_format = ( + "%(year)04d-%(month)02d-%(day)02d " + "%(hour)02d:%(minute)02d:%(second)02d" + ) def bind_processor(self, dialect): datetime_datetime = datetime.datetime @@ -124,12 +137,25 @@ if value is None: return None elif isinstance(value, datetime_datetime): - return format % (value.year, value.month, value.day, - value.hour, value.minute, value.second, - value.microsecond) + return format % { + 'year': value.year, + 'month': value.month, + 'day': value.day, + 'hour': value.hour, + 'minute': value.minute, + 'second': value.second, + 'microsecond': value.microsecond, + } elif isinstance(value, datetime_date): - return format % (value.year, value.month, value.day, - 0, 0, 0, 0) + return format % { + 'year': value.year, + 'month': value.month, + 'day': value.day, + 'hour': 0, + 'minute': 0, + 'second': 0, + 'microsecond': 0, + } else: raise TypeError("SQLite DateTime type only accepts Python " "datetime and date objects as input.") @@ -147,7 +173,7 @@ class DATE(_DateTimeMixin, sqltypes.Date): The default string storage format is:: - "%04d-%02d-%02d" % (value.year, value.month, value.day) + "%(year)04d-%(month)02d-%(day)02d" e.g.:: @@ -160,22 +186,22 @@ from sqlalchemy.dialects.sqlite import DATE d = DATE( - storage_format="%02d/%02d/%02d", - regexp=re.compile("(\d+)/(\d+)/(\d+)") + storage_format="%(month)02d/%(day)02d/%(year)04d", + regexp=re.compile("(?P<month>\d+)/(?P<day>\d+)/(?P<year>\d+)") ) :param storage_format: format string which will be applied to the - tuple ``(value.year, value.month, value.day)``, - given a Python datetime.date() object. + dict with keys year, month, and day. :param regexp: regular expression which will be applied to - incoming result rows. The resulting match object is applied to - the Python date() constructor via ``*map(int, - match_obj.groups(0))``. - + incoming result rows. If the regexp contains named groups, the + resulting match dict is applied to the Python date() constructor + as keyword arguments. Otherwise, if positional groups are used, + the date() constructor is called with positional arguments via + ``*map(int, match_obj.groups(0))``. """ - _storage_format = "%04d-%02d-%02d" + _storage_format = "%(year)04d-%(month)02d-%(day)02d" def bind_processor(self, dialect): datetime_date = datetime.date @@ -184,7 +210,11 @@ if value is None: return None elif isinstance(value, datetime_date): - return format % (value.year, value.month, value.day) + return format % { + 'year': value.year, + 'month': value.month, + 'day': value.day, + } else: raise TypeError("SQLite Date type only accepts Python " "date objects as input.") @@ -202,9 +232,7 @@ class TIME(_DateTimeMixin, sqltypes.Time): The default string storage format is:: - "%02d:%02d:%02d.%06d" % (value.hour, value.minute, - value.second, - value.microsecond) + "%(hour)02d:%(minute)02d:%(second)02d.%(microsecond)06d" e.g.:: @@ -217,22 +245,32 @@ from sqlalchemy.dialects.sqlite import TIME t = TIME( - storage_format="%02d-%02d-%02d-%06d", + storage_format="%(hour)02d-%(minute)02d-%(second)02d-%(microsecond)06d", regexp=re.compile("(\d+)-(\d+)-(\d+)(?:-(\d+))?") ) - :param storage_format: format string which will be applied - to the tuple ``(value.hour, value.minute, value.second, - value.microsecond)``, given a Python datetime.time() object.
+ :param storage_format: format string which will be applied to the + dict with keys hour, minute, second, and microsecond. :param regexp: regular expression which will be applied to - incoming result rows. The resulting match object is applied to - the Python time() constructor via ``*map(int, - match_obj.groups(0))``. - + incoming result rows. If the regexp contains named groups, the + resulting match dict is applied to the Python time() constructor + as keyword arguments. Otherwise, if positional groups are used, + the time() constructor is called with positional arguments via + ``*map(int, match_obj.groups(0))``. """ - _storage_format = "%02d:%02d:%02d.%06d" + _storage_format = "%(hour)02d:%(minute)02d:%(second)02d.%(microsecond)06d" + + def __init__(self, *args, **kwargs): + truncate_microseconds = kwargs.pop('truncate_microseconds', False) + super(TIME, self).__init__(*args, **kwargs) + if truncate_microseconds: + assert 'storage_format' not in kwargs, "You can specify only "\ + "one of truncate_microseconds or storage_format." + assert 'regexp' not in kwargs, "You can specify only one of "\ + "truncate_microseconds or regexp." + self._storage_format = "%(hour)02d:%(minute)02d:%(second)02d" def bind_processor(self, dialect): datetime_time = datetime.time @@ -241,8 +279,12 @@ if value is None: return None elif isinstance(value, datetime_time): - return format % (value.hour, value.minute, value.second, - value.microsecond) + return format % { + 'hour': value.hour, + 'minute': value.minute, + 'second': value.second, + 'microsecond': value.microsecond, + } else: raise TypeError("SQLite Time type only accepts Python " "time objects as input.") diff --git a/lib/sqlalchemy/engine/__init__.py b/lib/sqlalchemy/engine/__init__.py index 23b4b0b3b..c3667dd33 100644 --- a/lib/sqlalchemy/engine/__init__.py +++ b/lib/sqlalchemy/engine/__init__.py @@ -143,6 +143,12 @@ def create_engine(*args, **kwargs): :class:`.String` type - see that type for further details. + :param case_sensitive=True: if False, result column names + will match in a case-insensitive fashion, that is, + ``row['SomeColumn']``. By default, result row names + match case-sensitively as of version 0.8. In version + 0.7 and prior, all matches were case-insensitive. + :param connect_args: a dictionary of options which will be passed directly to the DBAPI's ``connect()`` method as additional keyword arguments. See the example diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 10ce6d819..a2695e337 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -259,6 +259,18 @@ class Dialect(object): raise NotImplementedError() + def get_primary_keys(self, connection, table_name, schema=None, **kw): + """Return information about primary keys in `table_name`. + + Deprecated. This method is only called by the default + implementation of :meth:`get_pk_constraint()`. Dialects should + instead implement :meth:`get_pk_constraint` directly. + + """ + + raise NotImplementedError() + def get_pk_constraint(self, connection, table_name, schema=None, **kw): """Return information about the primary key constraint on `table_name`.
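Tying the ``sqlite/base.py`` hunks above together: storage formats are now rendered against a dict of named date fields, and a ``regexp`` with named groups feeds the constructor by keyword. A sketch of a custom round-trippable DATETIME under the parameters documented above (the format itself is arbitrary)::

    import re
    from sqlalchemy.dialects.sqlite import DATETIME

    dt = DATETIME(
        # rendered against a dict with keys year, month, day,
        # hour, minute, second (microseconds omitted here)
        storage_format=(
            "%(year)04d-%(month)02d-%(day)02d "
            "%(hour)02d:%(minute)02d:%(second)02d"
        ),
        # named groups are passed to datetime() as keyword arguments
        regexp=re.compile(
            r"(?P<year>\d+)-(?P<month>\d+)-(?P<day>\d+) "
            r"(?P<hour>\d+):(?P<minute>\d+):(?P<second>\d+)"
        ),
    )

The new ``truncate_microseconds=True`` flag in ``__init__`` above installs the same microsecond-free storage format without requiring a custom format string.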
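The ``case_sensitive`` flag documented in the ``create_engine()`` hunk above is the compatibility switch for the new 0.8 result-row behavior; setting it to ``False`` restores 0.7-style case-insensitive matching. A short sketch (SQLite is used here purely for illustration)::

    from sqlalchemy import create_engine

    engine = create_engine("sqlite://", case_sensitive=False)
    row = engine.execute("SELECT 1 AS SomeColumn").fetchone()

    # under case_sensitive=False either spelling matches; with the
    # 0.8 default of True, only row["SomeColumn"] would succeed
    assert row["somecolumn"] == row["SomeColumn"] == 1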
@@ -1797,10 +1809,19 @@ class Connection(Connectable): (statement is not None and context is None) if should_wrap and context: + if self._has_events: + self.engine.dispatch.dbapi_error(self, + cursor, + statement, + parameters, + context, + e) context.handle_dbapi_exception(e) is_disconnect = isinstance(e, self.dialect.dbapi.Error) and \ self.dialect.is_disconnect(e, self.__connection, cursor) + + if is_disconnect: self.invalidate(e) self.engine.dispose() @@ -2337,7 +2358,11 @@ class Engine(Connectable, log.Identified): """ conn = self.contextual_connect(close_with_result=close_with_result) - trans = conn.begin() + try: + trans = conn.begin() + except: + conn.close() + raise return Engine._trans_ctx(conn, trans, close_with_result) def transaction(self, callable_, *args, **kwargs): @@ -2702,6 +2727,7 @@ class ResultMetaData(object): dialect = context.dialect typemap = dialect.dbapi_type_map translate_colname = dialect._translate_colname + self.case_sensitive = dialect.case_sensitive # high precedence key values. primary_keymap = {} @@ -2716,9 +2742,14 @@ class ResultMetaData(object): if translate_colname: colname, untranslated = translate_colname(colname) + if dialect.requires_name_normalize: + colname = dialect.normalize_name(colname) + if context.result_map: try: - name, obj, type_ = context.result_map[colname.lower()] + name, obj, type_ = context.result_map[colname + if self.case_sensitive + else colname.lower()] except KeyError: name, obj, type_ = \ colname, None, typemap.get(coltype, types.NULLTYPE) @@ -2736,17 +2767,20 @@ class ResultMetaData(object): primary_keymap[i] = rec # populate primary keymap, looking for conflicts. - if primary_keymap.setdefault(name.lower(), rec) is not rec: + if primary_keymap.setdefault( + name if self.case_sensitive + else name.lower(), + rec) is not rec: # place a record that doesn't have the "index" - this # is interpreted later as an AmbiguousColumnError, # but only when actually accessed. Columns # colliding by name is not a problem if those names # aren't used; integer and ColumnElement access is always # unambiguous. - primary_keymap[name.lower()] = (processor, obj, None) + primary_keymap[name + if self.case_sensitive + else name.lower()] = (processor, obj, None) - if dialect.requires_name_normalize: - colname = dialect.normalize_name(colname) self.keys.append(colname) if obj: @@ -2775,7 +2809,9 @@ class ResultMetaData(object): row. 
""" - rec = (processor, obj, i) = self._keymap[origname.lower()] + rec = (processor, obj, i) = self._keymap[origname if + self.case_sensitive + else origname.lower()] if self._keymap.setdefault(name, rec) is not rec: self._keymap[name] = (processor, obj, None) @@ -2783,17 +2819,27 @@ class ResultMetaData(object): map = self._keymap result = None if isinstance(key, basestring): - result = map.get(key.lower()) + result = map.get(key if self.case_sensitive else key.lower()) # fallback for targeting a ColumnElement to a textual expression # this is a rare use case which only occurs when matching text() # or colummn('name') constructs to ColumnElements, or after a # pickle/unpickle roundtrip elif isinstance(key, expression.ColumnElement): - if key._label and key._label.lower() in map: - result = map[key._label.lower()] - elif hasattr(key, 'name') and key.name.lower() in map: + if key._label and ( + key._label + if self.case_sensitive + else key._label.lower()) in map: + result = map[key._label + if self.case_sensitive + else key._label.lower()] + elif hasattr(key, 'name') and ( + key.name + if self.case_sensitive + else key.name.lower()) in map: # match is only on name. - result = map[key.name.lower()] + result = map[key.name + if self.case_sensitive + else key.name.lower()] # search extra hard to make sure this # isn't a column/label name overlap. # this check isn't currently available if the row @@ -2829,7 +2875,8 @@ class ResultMetaData(object): for key, (processor, obj, index) in self._keymap.iteritems() if isinstance(key, (basestring, int)) ), - 'keys': self.keys + 'keys': self.keys, + "case_sensitive":self.case_sensitive, } def __setstate__(self, state): @@ -2842,6 +2889,7 @@ class ResultMetaData(object): # proxy comparison fails with the unpickle keymap[key] = (None, None, index) self.keys = state['keys'] + self.case_sensitive = state['case_sensitive'] self._echo = False diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index cbf33dcaf..1f72d005d 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -105,6 +105,7 @@ class DefaultDialect(base.Dialect): def __init__(self, convert_unicode=False, assert_unicode=False, encoding='utf-8', paramstyle=None, dbapi=None, implicit_returning=None, + case_sensitive=True, label_length=None, **kwargs): if not getattr(self, 'ported_sqla_06', True): @@ -139,6 +140,8 @@ class DefaultDialect(base.Dialect): self.identifier_preparer = self.preparer(self) self.type_compiler = self.type_compiler(self) + self.case_sensitive = case_sensitive + if label_length and label_length > self.max_identifier_length: raise exc.ArgumentError( "Label length of %d is greater than this dialect's" @@ -263,6 +266,17 @@ class DefaultDialect(base.Dialect): insp = reflection.Inspector.from_engine(connection) return insp.reflecttable(table, include_columns, exclude_columns) + def get_pk_constraint(self, conn, table_name, schema=None, **kw): + """Compatibility method, adapts the result of get_primary_keys() + for those dialects which don't implement get_pk_constraint(). 
+ + """ + return { + 'constrained_columns': + self.get_primary_keys(conn, table_name, + schema=schema, **kw) + } + def validate_identifier(self, ident): if len(ident) > self.max_identifier_length: raise exc.IdentifierError( diff --git a/lib/sqlalchemy/engine/reflection.py b/lib/sqlalchemy/engine/reflection.py index f94e9ee16..13a7e1b88 100644 --- a/lib/sqlalchemy/engine/reflection.py +++ b/lib/sqlalchemy/engine/reflection.py @@ -31,7 +31,8 @@ from sqlalchemy import util from sqlalchemy.types import TypeEngine from sqlalchemy.util import deprecated from sqlalchemy.util import topological - +from sqlalchemy import inspection +from sqlalchemy.engine.base import Connectable @util.decorator def cache(fn, self, con, *args, **kw): @@ -119,6 +120,10 @@ class Inspector(object): return bind.dialect.inspector(bind) return Inspector(bind) + @inspection._inspects(Connectable) + def _insp(bind): + return Inspector.from_engine(bind) + @property def default_schema_name(self): """Return the default schema name presented by the dialect @@ -238,10 +243,9 @@ class Inspector(object): primary key information as a list of column names. """ - pkeys = self.dialect.get_pk_constraint(self.bind, table_name, schema, + return self.dialect.get_pk_constraint(self.bind, table_name, schema, info_cache=self.info_cache, **kw)['constrained_columns'] - return pkeys def get_pk_constraint(self, table_name, schema=None, **kw): """Return information about primary key constraint on `table_name`. @@ -256,12 +260,10 @@ class Inspector(object): optional name of the primary key constraint. """ - pkeys = self.dialect.get_pk_constraint(self.bind, table_name, schema, + return self.dialect.get_pk_constraint(self.bind, table_name, schema, info_cache=self.info_cache, **kw) - return pkeys - def get_foreign_keys(self, table_name, schema=None, **kw): """Return information about foreign_keys in `table_name`. @@ -290,10 +292,9 @@ class Inspector(object): """ - fk_defs = self.dialect.get_foreign_keys(self.bind, table_name, schema, + return self.dialect.get_foreign_keys(self.bind, table_name, schema, info_cache=self.info_cache, **kw) - return fk_defs def get_indexes(self, table_name, schema=None, **kw): """Return information about indexes in `table_name`. @@ -314,10 +315,9 @@ class Inspector(object): other options passed to the dialect's get_indexes() method. """ - indexes = self.dialect.get_indexes(self.bind, table_name, + return self.dialect.get_indexes(self.bind, table_name, schema, info_cache=self.info_cache, **kw) - return indexes def reflecttable(self, table, include_columns, exclude_columns=()): """Given a Table object, load its internal constructs based on introspection. @@ -371,7 +371,7 @@ class Inspector(object): found_table = False for col_d in self.get_columns(table_name, schema, **tblkw): found_table = True - table.dispatch.column_reflect(table, col_d) + table.dispatch.column_reflect(self, table, col_d) name = col_d['name'] if include_columns and name not in include_columns: diff --git a/lib/sqlalchemy/engine/url.py b/lib/sqlalchemy/engine/url.py index 392ecda11..5bbdb9d65 100644 --- a/lib/sqlalchemy/engine/url.py +++ b/lib/sqlalchemy/engine/url.py @@ -14,6 +14,7 @@ be used directly and is also accepted directly by ``create_engine()``. import re, urllib from sqlalchemy import exc, util +from sqlalchemy.engine import base class URL(object): @@ -96,49 +97,21 @@ class URL(object): to this URL's driver name. 
""" - try: - if '+' in self.drivername: - dialect, driver = self.drivername.split('+') - else: - dialect, driver = self.drivername, 'base' - - module = __import__('sqlalchemy.dialects.%s' % (dialect, )).dialects - module = getattr(module, dialect) - if hasattr(module, driver): - module = getattr(module, driver) - else: - module = self._load_entry_point() - if module is None: - raise exc.ArgumentError( - "Could not determine dialect for '%s'." % - self.drivername) - - return module.dialect - except ImportError: - module = self._load_entry_point() - if module is not None: - return module - else: - raise exc.ArgumentError( - "Could not determine dialect for '%s'." % self.drivername) - - def _load_entry_point(self): - """attempt to load this url's dialect from entry points, or return None - if pkg_resources is not installed or there is no matching entry point. - - Raise ImportError if the actual load fails. - - """ - try: - import pkg_resources - except ImportError: - return None - - for res in pkg_resources.iter_entry_points('sqlalchemy.dialects'): - if res.name == self.drivername.replace("+", "."): - return res.load() + if '+' not in self.drivername: + name = self.drivername + else: + name = self.drivername.replace('+', '.') + from sqlalchemy.dialects import registry + cls = registry.load(name) + # check for legacy dialects that + # would return a module with 'dialect' as the + # actual class + if hasattr(cls, 'dialect') and \ + isinstance(cls.dialect, type) and \ + issubclass(cls.dialect, base.Dialect): + return cls.dialect else: - return None + return cls def translate_connect_args(self, names=[], **kw): """Translate url attributes into a dictionary of connection arguments. diff --git a/lib/sqlalchemy/events.py b/lib/sqlalchemy/events.py index 504dfe150..99af804f6 100644 --- a/lib/sqlalchemy/events.py +++ b/lib/sqlalchemy/events.py @@ -166,7 +166,7 @@ class DDLEvents(event.Events): """ - def column_reflect(self, table, column_info): + def column_reflect(self, inspector, table, column_info): """Called for each unit of 'column info' retrieved when a :class:`.Table` is being reflected. @@ -188,7 +188,7 @@ class DDLEvents(event.Events): from sqlalchemy.schema import Table from sqlalchemy import event - def listen_for_reflect(table, column_info): + def listen_for_reflect(inspector, table, column_info): "receive a column_reflect event" # ... @@ -200,7 +200,7 @@ class DDLEvents(event.Events): ...or with a specific :class:`.Table` instance using the ``listeners`` argument:: - def listen_for_reflect(table, column_info): + def listen_for_reflect(inspector, table, column_info): "receive a column_reflect event" # ... @@ -401,6 +401,36 @@ class ConnectionEvents(event.Events): parameters, context, executemany): """Intercept low-level cursor execute() events.""" + def dbapi_error(self, conn, cursor, statement, parameters, + context, exception): + """Intercept a raw DBAPI error. + + This event is called with the DBAPI exception instance + received from the DBAPI itself, *before* SQLAlchemy wraps the + exception with it's own exception wrappers, and before any + other operations are performed on the DBAPI cursor; the + existing transaction remains in effect as well as any state + on the cursor. + + The use case here is to inject low-level exception handling + into an :class:`.Engine`, typically for logging and + debugging purposes. In general, user code should **not** modify + any state or throw any exceptions here as this will + interfere with SQLAlchemy's cleanup and error handling + routines. 
+ + Subsequent to this hook, SQLAlchemy may attempt any + number of operations on the connection/cursor, including + closing the cursor, rolling back of the transaction in the + case of connectionless execution, and disposing of the entire + connection pool if a "disconnect" was detected. The + exception is then wrapped in a SQLAlchemy DBAPI exception + wrapper and re-thrown. + + New in 0.7.7. + + """ + def begin(self, conn): """Intercept begin() events.""" diff --git a/lib/sqlalchemy/exc.py b/lib/sqlalchemy/exc.py index 91ffc2811..f28bd8a07 100644 --- a/lib/sqlalchemy/exc.py +++ b/lib/sqlalchemy/exc.py @@ -25,6 +25,13 @@ class ArgumentError(SQLAlchemyError): """ +class NoForeignKeysError(ArgumentError): + """Raised when no foreign keys can be located between two selectables + during a join.""" + +class AmbiguousForeignKeysError(ArgumentError): + """Raised when more than one foreign key matching can be located + between two selectables during a join.""" class CircularDependencyError(SQLAlchemyError): """Raised by topological sorts when a circular dependency is detected. diff --git a/lib/sqlalchemy/ext/associationproxy.py b/lib/sqlalchemy/ext/associationproxy.py index 8b3416ea9..01a4a933f 100644 --- a/lib/sqlalchemy/ext/associationproxy.py +++ b/lib/sqlalchemy/ext/associationproxy.py @@ -15,7 +15,7 @@ See the example ``examples/association/proxied_association.py``. import itertools import operator import weakref -from sqlalchemy import exceptions +from sqlalchemy import exc from sqlalchemy import orm from sqlalchemy import util from sqlalchemy.orm import collections, ColumnProperty @@ -295,7 +295,7 @@ class AssociationProxy(object): elif self.collection_class is set: return _AssociationSet(lazy_collection, creator, getter, setter, self) else: - raise exceptions.ArgumentError( + raise exc.ArgumentError( 'could not guess which interface to use for ' 'collection_class "%s" backing "%s"; specify a ' 'proxy_factory and proxy_bulk_set manually' % @@ -323,7 +323,7 @@ class AssociationProxy(object): elif self.collection_class is set: proxy.update(values) else: - raise exceptions.ArgumentError( + raise exc.ArgumentError( 'no proxy_bulk_set supplied for custom ' 'collection_class implementation') @@ -405,7 +405,7 @@ class _lazy_collection(object): def __call__(self): obj = self.ref() if obj is None: - raise exceptions.InvalidRequestError( + raise exc.InvalidRequestError( "stale association proxy, parent object has gone out of " "scope") return getattr(obj, self.target) diff --git a/lib/sqlalchemy/ext/declarative.py b/lib/sqlalchemy/ext/declarative.py index faf575da1..07a2a3b95 100755 --- a/lib/sqlalchemy/ext/declarative.py +++ b/lib/sqlalchemy/ext/declarative.py @@ -1398,6 +1398,10 @@ class _GetTable(object): def _deferred_relationship(cls, prop): def resolve_arg(arg): import sqlalchemy + from sqlalchemy.orm import foreign, remote + + fallback = sqlalchemy.__dict__.copy() + fallback.update({'foreign':foreign, 'remote':remote}) def access_cls(key): if key in cls._decl_class_registry: @@ -1407,7 +1411,7 @@ def _deferred_relationship(cls, prop): elif key in cls.metadata._schemas: return _GetTable(key, cls.metadata) else: - return sqlalchemy.__dict__[key] + return fallback[key] d = util.PopulateDict(access_cls) def return_cls(): diff --git a/lib/sqlalchemy/ext/sqlsoup.py b/lib/sqlalchemy/ext/sqlsoup.py deleted file mode 100644 index 3727f5757..000000000 --- a/lib/sqlalchemy/ext/sqlsoup.py +++ /dev/null @@ -1,813 +0,0 @@ -# ext/sqlsoup.py -# Copyright (C) 2005-2012 the SQLAlchemy authors and contributors 
<see AUTHORS file> -# -# This module is part of SQLAlchemy and is released under -# the MIT License: http://www.opensource.org/licenses/mit-license.php - -""" - -.. note:: - - SQLSoup is now its own project. Documentation - and project status are available at: - - http://pypi.python.org/pypi/sqlsoup - - http://readthedocs.org/docs/sqlsoup - - SQLSoup will no longer be included with SQLAlchemy as of - version 0.8. - - -Introduction -============ - -SqlSoup provides a convenient way to access existing database -tables without having to declare table or mapper classes ahead -of time. It is built on top of the SQLAlchemy ORM and provides a -super-minimalistic interface to an existing database. - -SqlSoup effectively provides a coarse grained, alternative -interface to working with the SQLAlchemy ORM, providing a "self -configuring" interface for extremely rudimental operations. It's -somewhat akin to a "super novice mode" version of the ORM. While -SqlSoup can be very handy, users are strongly encouraged to use -the full ORM for non-trivial applications. - -Suppose we have a database with users, books, and loans tables -(corresponding to the PyWebOff dataset, if you're curious). - -Creating a SqlSoup gateway is just like creating an SQLAlchemy -engine:: - - >>> from sqlalchemy.ext.sqlsoup import SqlSoup - >>> db = SqlSoup('sqlite:///:memory:') - -or, you can re-use an existing engine:: - - >>> db = SqlSoup(engine) - -You can optionally specify a schema within the database for your -SqlSoup:: - - >>> db.schema = myschemaname - -Loading objects -=============== - -Loading objects is as easy as this:: - - >>> users = db.users.all() - >>> users.sort() - >>> users - [ - MappedUsers(name=u'Joe Student',email=u'student@example.edu', - password=u'student',classname=None,admin=0), - MappedUsers(name=u'Bhargan Basepair',email=u'basepair@example.edu', - password=u'basepair',classname=None,admin=1) - ] - -Of course, letting the database do the sort is better:: - - >>> db.users.order_by(db.users.name).all() - [ - MappedUsers(name=u'Bhargan Basepair',email=u'basepair@example.edu', - password=u'basepair',classname=None,admin=1), - MappedUsers(name=u'Joe Student',email=u'student@example.edu', - password=u'student',classname=None,admin=0) - ] - -Field access is intuitive:: - - >>> users[0].email - u'student@example.edu' - -Of course, you don't want to load all users very often. Let's -add a WHERE clause. 
Let's also switch the order_by to DESC while -we're at it:: - - >>> from sqlalchemy import or_, and_, desc - >>> where = or_(db.users.name=='Bhargan Basepair', db.users.email=='student@example.edu') - >>> db.users.filter(where).order_by(desc(db.users.name)).all() - [ - MappedUsers(name=u'Joe Student',email=u'student@example.edu', - password=u'student',classname=None,admin=0), - MappedUsers(name=u'Bhargan Basepair',email=u'basepair@example.edu', - password=u'basepair',classname=None,admin=1) - ] - -You can also use .first() (to retrieve only the first object -from a query) or .one() (like .first when you expect exactly one -user -- it will raise an exception if more were returned):: - - >>> db.users.filter(db.users.name=='Bhargan Basepair').one() - MappedUsers(name=u'Bhargan Basepair',email=u'basepair@example.edu', - password=u'basepair',classname=None,admin=1) - -Since name is the primary key, this is equivalent to - - >>> db.users.get('Bhargan Basepair') - MappedUsers(name=u'Bhargan Basepair',email=u'basepair@example.edu', - password=u'basepair',classname=None,admin=1) - -This is also equivalent to - - >>> db.users.filter_by(name='Bhargan Basepair').one() - MappedUsers(name=u'Bhargan Basepair',email=u'basepair@example.edu', - password=u'basepair',classname=None,admin=1) - -filter_by is like filter, but takes kwargs instead of full -clause expressions. This makes it more concise for simple -queries like this, but you can't do complex queries like the -or\_ above or non-equality based comparisons this way. - -Full query documentation ------------------------- - -Get, filter, filter_by, order_by, limit, and the rest of the -query methods are explained in detail in -:ref:`ormtutorial_querying`. - -Modifying objects -================= - -Modifying objects is intuitive:: - - >>> user = _ - >>> user.email = 'basepair+nospam@example.edu' - >>> db.commit() - -(SqlSoup leverages the sophisticated SQLAlchemy unit-of-work -code, so multiple updates to a single object will be turned into -a single ``UPDATE`` statement when you commit.) - -To finish covering the basics, let's insert a new loan, then -delete it:: - - >>> book_id = db.books.filter_by(title='Regional Variation in Moss').first().id - >>> db.loans.insert(book_id=book_id, user_name=user.name) - MappedLoans(book_id=2,user_name=u'Bhargan Basepair',loan_date=None) - - >>> loan = db.loans.filter_by(book_id=2, user_name='Bhargan Basepair').one() - >>> db.delete(loan) - >>> db.commit() - -You can also delete rows that have not been loaded as objects. -Let's do our insert/delete cycle once more, this time using the -loans table's delete method. (For SQLAlchemy experts: note that -no flush() call is required since this delete acts at the SQL -level, not at the Mapper level.) The same where-clause -construction rules apply here as to the select methods:: - - >>> db.loans.insert(book_id=book_id, user_name=user.name) - MappedLoans(book_id=2,user_name=u'Bhargan Basepair',loan_date=None) - >>> db.loans.delete(db.loans.book_id==2) - -You can similarly update multiple rows at once. This will change the -book_id to 1 in all loans whose book_id is 2:: - - >>> db.loans.filter_by(db.loans.book_id==2).update({'book_id':1}) - >>> db.loans.filter_by(book_id=1).all() - [MappedLoans(book_id=1,user_name=u'Joe Student', - loan_date=datetime.datetime(2006, 7, 12, 0, 0))] - - -Joins -===== - -Occasionally, you will want to pull out a lot of data from related -tables all at once. 
In this situation, it is far more efficient to -have the database perform the necessary join. (Here we do not have *a -lot of data* but hopefully the concept is still clear.) SQLAlchemy is -smart enough to recognize that loans has a foreign key to users, and -uses that as the join condition automatically:: - - >>> join1 = db.join(db.users, db.loans, isouter=True) - >>> join1.filter_by(name='Joe Student').all() - [ - MappedJoin(name=u'Joe Student',email=u'student@example.edu', - password=u'student',classname=None,admin=0,book_id=1, - user_name=u'Joe Student',loan_date=datetime.datetime(2006, 7, 12, 0, 0)) - ] - -If you're unfortunate enough to be using MySQL with the default MyISAM -storage engine, you'll have to specify the join condition manually, -since MyISAM does not store foreign keys. Here's the same join again, -with the join condition explicitly specified:: - - >>> db.join(db.users, db.loans, db.users.name==db.loans.user_name, isouter=True) - <class 'sqlalchemy.ext.sqlsoup.MappedJoin'> - -You can compose arbitrarily complex joins by combining Join objects -with tables or other joins. Here we combine our first join with the -books table:: - - >>> join2 = db.join(join1, db.books) - >>> join2.all() - [ - MappedJoin(name=u'Joe Student',email=u'student@example.edu', - password=u'student',classname=None,admin=0,book_id=1, - user_name=u'Joe Student',loan_date=datetime.datetime(2006, 7, 12, 0, 0), - id=1,title=u'Mustards I Have Known',published_year=u'1989', - authors=u'Jones') - ] - -If you join tables that have an identical column name, wrap your join -with `with_labels`, to disambiguate columns with their table name -(.c is short for .columns):: - - >>> db.with_labels(join1).c.keys() - [u'users_name', u'users_email', u'users_password', - u'users_classname', u'users_admin', u'loans_book_id', - u'loans_user_name', u'loans_loan_date'] - -You can also join directly to a labeled object:: - - >>> labeled_loans = db.with_labels(db.loans) - >>> db.join(db.users, labeled_loans, isouter=True).c.keys() - [u'name', u'email', u'password', u'classname', - u'admin', u'loans_book_id', u'loans_user_name', u'loans_loan_date'] - - -Relationships -============= - -You can define relationships on SqlSoup classes: - - >>> db.users.relate('loans', db.loans) - -These can then be used like a normal SA property: - - >>> db.users.get('Joe Student').loans - [MappedLoans(book_id=1,user_name=u'Joe Student', - loan_date=datetime.datetime(2006, 7, 12, 0, 0))] - - >>> db.users.filter(~db.users.loans.any()).all() - [MappedUsers(name=u'Bhargan Basepair', - email='basepair+nospam@example.edu', - password=u'basepair',classname=None,admin=1)] - -relate can take any options that the relationship function -accepts in normal mapper definition: - - >>> del db._cache['users'] - >>> db.users.relate('loans', db.loans, order_by=db.loans.loan_date, cascade='all, delete-orphan') - -Advanced Use -============ - -Sessions, Transactions and Application Integration -------------------------------------------------- - -.. note:: - - Please read and understand this section thoroughly - before using SqlSoup in any web application. - -SqlSoup uses a ScopedSession to provide thread-local sessions. -You can get a reference to the current one like this:: - - >>> session = db.session - -The default session is available at the module level in SQLSoup, -via:: - - >>> from sqlalchemy.ext.sqlsoup import Session - -The configuration of this session is ``autoflush=True``, -``autocommit=False``. 
This means when you work with the SqlSoup -object, you need to call ``db.commit()`` in order to have -changes persisted. You may also call ``db.rollback()`` to roll -things back. - -Since the SqlSoup object's Session automatically enters into a -transaction as soon as it's used, it is *essential* that you -call ``commit()`` or ``rollback()`` on it when the work within a -thread completes. This means all the guidelines for web -application integration at :ref:`session_lifespan` must be -followed. - -The SqlSoup object can have any session or scoped session -configured onto it. This is of key importance when integrating -with existing code or frameworks such as Pylons. If your -application already has a ``Session`` configured, pass it to -your SqlSoup object:: - - >>> from myapplication import Session - >>> db = SqlSoup(session=Session) - -If the ``Session`` is configured with ``autocommit=True``, use -``flush()`` instead of ``commit()`` to persist changes - in this -case, the ``Session`` closes out its transaction immediately and -no external management is needed. ``rollback()`` is also not -available. Configuring a new SQLSoup object in "autocommit" mode -looks like:: - - >>> from sqlalchemy.orm import scoped_session, sessionmaker - >>> db = SqlSoup('sqlite://', session=scoped_session(sessionmaker(autoflush=False, expire_on_commit=False, autocommit=True))) - - -Mapping arbitrary Selectables ------------------------------ - -SqlSoup can map any SQLAlchemy :class:`.Selectable` with the map -method. Let's map an :func:`.expression.select` object that uses an aggregate -function; we'll use the SQLAlchemy :class:`.Table` that SqlSoup -introspected as the basis. (Since we're not mapping to a simple -table or join, we need to tell SQLAlchemy how to find the -*primary key* which just needs to be unique within the select, -and not necessarily correspond to a *real* PK in the database.):: - - >>> from sqlalchemy import select, func - >>> b = db.books._table - >>> s = select([b.c.published_year, func.count('*').label('n')], from_obj=[b], group_by=[b.c.published_year]) - >>> s = s.alias('years_with_count') - >>> years_with_count = db.map(s, primary_key=[s.c.published_year]) - >>> years_with_count.filter_by(published_year='1989').all() - [MappedBooks(published_year=u'1989',n=1)] - -Obviously if we just wanted to get a list of counts associated with -book years once, raw SQL is going to be less work. The advantage of -mapping a Select is reusability, both standalone and in Joins. (And if -you go to full SQLAlchemy, you can perform mappings like this directly -to your object models.) - -An easy way to save mapped selectables like this is to just hang them on -your db object:: - - >>> db.years_with_count = years_with_count - -Python is flexible like that! - -Raw SQL -------- - -SqlSoup works fine with SQLAlchemy's text construct, described -in :ref:`sqlexpression_text`. You can also execute textual SQL -directly using the `execute()` method, which corresponds to the -`execute()` method on the underlying `Session`. Expressions here -are expressed like ``text()`` constructs, using named parameters -with colons:: - - >>> rp = db.execute('select name, email from users where name like :name order by name', name='%Bhargan%') - >>> for name, email in rp.fetchall(): print name, email - Bhargan Basepair basepair+nospam@example.edu - -Or you can get at the current transaction's connection using -`connection()`. 
This is the raw connection object which can -accept any sort of SQL expression or raw SQL string passed to -the database:: - - >>> conn = db.connection() - >>> conn.execute("'select name, email from users where name like ? order by name'", '%Bhargan%') - -Dynamic table names -------------------- - -You can load a table whose name is specified at runtime with the -entity() method: - - >>> tablename = 'loans' - >>> db.entity(tablename) == db.loans - True - -entity() also takes an optional schema argument. If none is -specified, the default schema is used. - -""" - -from sqlalchemy import Table, MetaData, join -from sqlalchemy import schema, sql, util -from sqlalchemy.engine.base import Engine -from sqlalchemy.orm import scoped_session, sessionmaker, mapper, \ - class_mapper, relationship, session,\ - object_session, attributes -from sqlalchemy.orm.interfaces import MapperExtension, EXT_CONTINUE -from sqlalchemy.exc import SQLAlchemyError, InvalidRequestError, ArgumentError -from sqlalchemy.sql import expression - - -__all__ = ['PKNotFoundError', 'SqlSoup'] - -Session = scoped_session(sessionmaker(autoflush=True, autocommit=False)) - -class AutoAdd(MapperExtension): - def __init__(self, scoped_session): - self.scoped_session = scoped_session - - def instrument_class(self, mapper, class_): - class_.__init__ = self._default__init__(mapper) - - def _default__init__(ext, mapper): - def __init__(self, **kwargs): - for key, value in kwargs.iteritems(): - setattr(self, key, value) - return __init__ - - def init_instance(self, mapper, class_, oldinit, instance, args, kwargs): - session = self.scoped_session() - state = attributes.instance_state(instance) - session._save_impl(state) - return EXT_CONTINUE - - def init_failed(self, mapper, class_, oldinit, instance, args, kwargs): - sess = object_session(instance) - if sess: - sess.expunge(instance) - return EXT_CONTINUE - -class PKNotFoundError(SQLAlchemyError): - pass - -def _ddl_error(cls): - msg = 'SQLSoup can only modify mapped Tables (found: %s)' \ - % cls._table.__class__.__name__ - raise InvalidRequestError(msg) - -# metaclass is necessary to expose class methods with getattr, e.g. 
-# we want to pass db.users.select through to users._mapper.select -class SelectableClassType(type): - def insert(cls, **kwargs): - _ddl_error(cls) - - def __clause_element__(cls): - return cls._table - - def __getattr__(cls, attr): - if attr == '_query': - # called during mapper init - raise AttributeError() - return getattr(cls._query, attr) - -class TableClassType(SelectableClassType): - def insert(cls, **kwargs): - o = cls() - o.__dict__.update(kwargs) - return o - - def relate(cls, propname, *args, **kwargs): - class_mapper(cls)._configure_property(propname, relationship(*args, **kwargs)) - -def _is_outer_join(selectable): - if not isinstance(selectable, sql.Join): - return False - if selectable.isouter: - return True - return _is_outer_join(selectable.left) or _is_outer_join(selectable.right) - -def _selectable_name(selectable): - if isinstance(selectable, sql.Alias): - return _selectable_name(selectable.element) - elif isinstance(selectable, sql.Select): - return ''.join(_selectable_name(s) for s in selectable.froms) - elif isinstance(selectable, schema.Table): - return selectable.name.capitalize() - else: - x = selectable.__class__.__name__ - if x[0] == '_': - x = x[1:] - return x - -def _class_for_table(session, engine, selectable, base_cls, mapper_kwargs): - selectable = expression._clause_element_as_expr(selectable) - mapname = 'Mapped' + _selectable_name(selectable) - # Py2K - if isinstance(mapname, unicode): - engine_encoding = engine.dialect.encoding - mapname = mapname.encode(engine_encoding) - # end Py2K - - if isinstance(selectable, Table): - klass = TableClassType(mapname, (base_cls,), {}) - else: - klass = SelectableClassType(mapname, (base_cls,), {}) - - def _compare(self, o): - L = list(self.__class__.c.keys()) - L.sort() - t1 = [getattr(self, k) for k in L] - try: - t2 = [getattr(o, k) for k in L] - except AttributeError: - raise TypeError('unable to compare with %s' % o.__class__) - return t1, t2 - - # python2/python3 compatible system of - # __cmp__ - __lt__ + __eq__ - - def __lt__(self, o): - t1, t2 = _compare(self, o) - return t1 < t2 - - def __eq__(self, o): - t1, t2 = _compare(self, o) - return t1 == t2 - - def __repr__(self): - L = ["%s=%r" % (key, getattr(self, key, '')) - for key in self.__class__.c.keys()] - return '%s(%s)' % (self.__class__.__name__, ','.join(L)) - - for m in ['__eq__', '__repr__', '__lt__']: - setattr(klass, m, eval(m)) - klass._table = selectable - klass.c = expression.ColumnCollection() - mappr = mapper(klass, - selectable, - extension=AutoAdd(session), - **mapper_kwargs) - - for k in mappr.iterate_properties: - klass.c[k.key] = k.columns[0] - - klass._query = session.query_property() - return klass - -class SqlSoup(object): - """Represent an ORM-wrapped database resource.""" - - def __init__(self, engine_or_metadata, base=object, session=None): - """Initialize a new :class:`.SqlSoup`. - - :param engine_or_metadata: a string database URL, :class:`.Engine` - or :class:`.MetaData` object to associate with. If the - argument is a :class:`.MetaData`, it should be *bound* - to an :class:`.Engine`. - :param base: a class which will serve as the default class for - returned mapped classes. Defaults to ``object``. - :param session: a :class:`.ScopedSession` or :class:`.Session` with - which to associate ORM operations for this :class:`.SqlSoup` instance. - If ``None``, a :class:`.ScopedSession` that's local to this - module is used. 
- - """ - - self.session = session or Session - self.base=base - - if isinstance(engine_or_metadata, MetaData): - self._metadata = engine_or_metadata - elif isinstance(engine_or_metadata, (basestring, Engine)): - self._metadata = MetaData(engine_or_metadata) - else: - raise ArgumentError("invalid engine or metadata argument %r" % - engine_or_metadata) - - self._cache = {} - self.schema = None - - @property - def bind(self): - """The :class:`.Engine` associated with this :class:`.SqlSoup`.""" - return self._metadata.bind - - engine = bind - - def delete(self, instance): - """Mark an instance as deleted.""" - - self.session.delete(instance) - - def execute(self, stmt, **params): - """Execute a SQL statement. - - The statement may be a string SQL string, - an :func:`.expression.select` construct, or an :func:`.expression.text` - construct. - - """ - return self.session.execute(sql.text(stmt, bind=self.bind), **params) - - @property - def _underlying_session(self): - if isinstance(self.session, session.Session): - return self.session - else: - return self.session() - - def connection(self): - """Return the current :class:`.Connection` in use by the current transaction.""" - - return self._underlying_session._connection_for_bind(self.bind) - - def flush(self): - """Flush pending changes to the database. - - See :meth:`.Session.flush`. - - """ - self.session.flush() - - def rollback(self): - """Rollback the current transaction. - - See :meth:`.Session.rollback`. - - """ - self.session.rollback() - - def commit(self): - """Commit the current transaction. - - See :meth:`.Session.commit`. - - """ - self.session.commit() - - def clear(self): - """Synonym for :meth:`.SqlSoup.expunge_all`.""" - - self.session.expunge_all() - - def expunge(self, instance): - """Remove an instance from the :class:`.Session`. - - See :meth:`.Session.expunge`. - - """ - self.session.expunge(instance) - - def expunge_all(self): - """Clear all objects from the current :class:`.Session`. - - See :meth:`.Session.expunge_all`. - - """ - self.session.expunge_all() - - def map_to(self, attrname, tablename=None, selectable=None, - schema=None, base=None, mapper_args=util.immutabledict()): - """Configure a mapping to the given attrname. - - This is the "master" method that can be used to create any - configuration. - - (new in 0.6.6) - - :param attrname: String attribute name which will be - established as an attribute on this :class:.`.SqlSoup` - instance. - :param base: a Python class which will be used as the - base for the mapped class. If ``None``, the "base" - argument specified by this :class:`.SqlSoup` - instance's constructor will be used, which defaults to - ``object``. - :param mapper_args: Dictionary of arguments which will - be passed directly to :func:`.orm.mapper`. - :param tablename: String name of a :class:`.Table` to be - reflected. If a :class:`.Table` is already available, - use the ``selectable`` argument. This argument is - mutually exclusive versus the ``selectable`` argument. - :param selectable: a :class:`.Table`, :class:`.Join`, or - :class:`.Select` object which will be mapped. This - argument is mutually exclusive versus the ``tablename`` - argument. - :param schema: String schema name to use if the - ``tablename`` argument is present. 
- - - """ - if attrname in self._cache: - raise InvalidRequestError( - "Attribute '%s' is already mapped to '%s'" % ( - attrname, - class_mapper(self._cache[attrname]).mapped_table - )) - - if tablename is not None: - if not isinstance(tablename, basestring): - raise ArgumentError("'tablename' argument must be a string." - ) - if selectable is not None: - raise ArgumentError("'tablename' and 'selectable' " - "arguments are mutually exclusive") - - selectable = Table(tablename, - self._metadata, - autoload=True, - autoload_with=self.bind, - schema=schema or self.schema) - elif schema: - raise ArgumentError("'tablename' argument is required when " - "using 'schema'.") - elif selectable is not None: - if not isinstance(selectable, expression.FromClause): - raise ArgumentError("'selectable' argument must be a " - "table, select, join, or other " - "selectable construct.") - else: - raise ArgumentError("'tablename' or 'selectable' argument is " - "required.") - - if not selectable.primary_key.columns: - if tablename: - raise PKNotFoundError( - "table '%s' does not have a primary " - "key defined" % tablename) - else: - raise PKNotFoundError( - "selectable '%s' does not have a primary " - "key defined" % selectable) - - mapped_cls = _class_for_table( - self.session, - self.engine, - selectable, - base or self.base, - mapper_args - ) - self._cache[attrname] = mapped_cls - return mapped_cls - - - def map(self, selectable, base=None, **mapper_args): - """Map a selectable directly. - - The class and its mapping are not cached and will - be discarded once dereferenced (as of 0.6.6). - - :param selectable: an :func:`.expression.select` construct. - :param base: a Python class which will be used as the - base for the mapped class. If ``None``, the "base" - argument specified by this :class:`.SqlSoup` - instance's constructor will be used, which defaults to - ``object``. - :param mapper_args: Dictionary of arguments which will - be passed directly to :func:`.orm.mapper`. - - """ - - return _class_for_table( - self.session, - self.engine, - selectable, - base or self.base, - mapper_args - ) - - def with_labels(self, selectable, base=None, **mapper_args): - """Map a selectable directly, wrapping the - selectable in a subquery with labels. - - The class and its mapping are not cached and will - be discarded once dereferenced (as of 0.6.6). - - :param selectable: an :func:`.expression.select` construct. - :param base: a Python class which will be used as the - base for the mapped class. If ``None``, the "base" - argument specified by this :class:`.SqlSoup` - instance's constructor will be used, which defaults to - ``object``. - :param mapper_args: Dictionary of arguments which will - be passed directly to :func:`.orm.mapper`. - - """ - - # TODO give meaningful aliases - return self.map( - expression._clause_element_as_expr(selectable). - select(use_labels=True). - alias('foo'), base=base, **mapper_args) - - def join(self, left, right, onclause=None, isouter=False, - base=None, **mapper_args): - """Create an :func:`.expression.join` and map to it. - - The class and its mapping are not cached and will - be discarded once dereferenced (as of 0.6.6). - - :param left: a mapped class or table object. - :param right: a mapped class or table object. - :param onclause: optional "ON" clause construct.. - :param isouter: if True, the join will be an OUTER join. - :param base: a Python class which will be used as the - base for the mapped class. 
If ``None``, the "base" - argument specified by this :class:`.SqlSoup` - instance's constructor will be used, which defaults to - ``object``. - :param mapper_args: Dictionary of arguments which will - be passed directly to :func:`.orm.mapper`. - - """ - - j = join(left, right, onclause=onclause, isouter=isouter) - return self.map(j, base=base, **mapper_args) - - def entity(self, attr, schema=None): - """Return the named entity from this :class:`.SqlSoup`, or - create if not present. - - For more generalized mapping, see :meth:`.map_to`. - - """ - try: - return self._cache[attr] - except KeyError, ke: - return self.map_to(attr, tablename=attr, schema=schema) - - def __getattr__(self, attr): - return self.entity(attr) - - def __repr__(self): - return 'SqlSoup(%r)' % self._metadata - diff --git a/lib/sqlalchemy/inspection.py b/lib/sqlalchemy/inspection.py new file mode 100644 index 000000000..9ce52beab --- /dev/null +++ b/lib/sqlalchemy/inspection.py @@ -0,0 +1,44 @@ +# sqlalchemy/inspection.py +# Copyright (C) 2005-2012 the SQLAlchemy authors and contributors <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Base inspect API. + +:func:`.inspect` provides access to a contextual object +regarding a subject. + +Various subsections of SQLAlchemy, +such as the :class:`.Inspector`, :class:`.Mapper`, and +others, register themselves with the "inspection registry" here +so that they may return a context object given a certain kind +of argument. +""" + +from sqlalchemy import exc, util +_registrars = util.defaultdict(list) + +def inspect(subject): + type_ = type(subject) + for cls in type_.__mro__: + if cls in _registrars: + reg = _registrars[cls] + break + else: + raise exc.InvalidRequestError( + "No inspection system is " + "available for object of type %s" % + type_) + return reg(subject) + +def _inspects(*types): + def decorate(fn_or_cls): + for type_ in types: + if type_ in _registrars: + raise AssertionError( + "Type %s is already " + "registered" % type_) + _registrars[type_] = fn_or_cls + return fn_or_cls + return decorate diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py index ca4099d68..d322d426b 100644 --- a/lib/sqlalchemy/orm/__init__.py +++ b/lib/sqlalchemy/orm/__init__.py @@ -35,6 +35,7 @@ from sqlalchemy.orm.util import ( outerjoin, polymorphic_union, with_parent, + with_polymorphic, ) from sqlalchemy.orm.properties import ( ColumnProperty, @@ -44,6 +45,11 @@ from sqlalchemy.orm.properties import ( PropertyLoader, SynonymProperty, ) +from sqlalchemy.orm.relationships import ( + foreign, + remote, + remote_foreign +) from sqlalchemy.orm import mapper as mapperlib from sqlalchemy.orm.mapper import reconstructor, validates from sqlalchemy.orm import strategies @@ -81,6 +87,7 @@ __all__ = ( 'dynamic_loader', 'eagerload', 'eagerload_all', + 'foreign', 'immediateload', 'join', 'joinedload', @@ -96,6 +103,8 @@ __all__ = ( 'reconstructor', 'relationship', 'relation', + 'remote', + 'remote_foreign', 'scoped_session', 'sessionmaker', 'subqueryload', diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index 3b4f18b31..7625ccead 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -16,66 +16,87 @@ defines a large part of the ORM's interactivity.
import operator from operator import itemgetter -from sqlalchemy import util, event, exc as sa_exc +from sqlalchemy import util, event, exc as sa_exc, inspection from sqlalchemy.orm import interfaces, collections, events, exc as orm_exc mapperutil = util.importlater("sqlalchemy.orm", "util") -PASSIVE_NO_RESULT = util.symbol('PASSIVE_NO_RESULT') -ATTR_WAS_SET = util.symbol('ATTR_WAS_SET') -ATTR_EMPTY = util.symbol('ATTR_EMPTY') -NO_VALUE = util.symbol('NO_VALUE') -NEVER_SET = util.symbol('NEVER_SET') - -PASSIVE_RETURN_NEVER_SET = util.symbol('PASSIVE_RETURN_NEVER_SET', -"""Symbol indicating that loader callables can be -fired off, but if no callable is applicable and no value is -present, the attribute should remain non-initialized. -NEVER_SET is returned in this case. -""") - -PASSIVE_NO_INITIALIZE = util.symbol('PASSIVE_NO_INITIALIZE', -"""Symbol indicating that loader callables should - not be fired off, and a non-initialized attribute - should remain that way. -""") - -PASSIVE_NO_FETCH = util.symbol('PASSIVE_NO_FETCH', -"""Symbol indicating that loader callables should not emit SQL, - but a value can be fetched from the current session. - - Non-initialized attributes should be initialized to an empty value. - -""") - -PASSIVE_NO_FETCH_RELATED = util.symbol('PASSIVE_NO_FETCH_RELATED', -"""Symbol indicating that loader callables should not emit SQL for - loading a related object, but can refresh the attributes of the local - instance in order to locate a related object in the current session. - - Non-initialized attributes should be initialized to an empty value. - - The unit of work uses this mode to check if history is present - on many-to-one attributes with minimal SQL emitted. +PASSIVE_NO_RESULT = util.symbol('PASSIVE_NO_RESULT', +"""Symbol returned by a loader callable or other attribute/history +retrieval operation when a value could not be determined, based +on loader callable flags. +""" +) +ATTR_WAS_SET = util.symbol('ATTR_WAS_SET', +"""Symbol returned by a loader callable to indicate the +retrieved value, or values, were assigned to their attributes +on the target object. """) -PASSIVE_ONLY_PERSISTENT = util.symbol('PASSIVE_ONLY_PERSISTENT', -"""Symbol indicating that loader callables should only fire off for - parent objects which are persistent (i.e., have a database - identity). - - Load operations for the "previous" value of an attribute make - use of this flag during change events. - +ATTR_EMPTY = util.symbol('ATTR_EMPTY', +"""Symbol used internally to indicate an attribute had no callable. """) -PASSIVE_OFF = util.symbol('PASSIVE_OFF', -"""Symbol indicating that loader callables should be executed - normally. +NO_VALUE = util.symbol('NO_VALUE', +"""Symbol which may be placed as the 'previous' value of an attribute, +indicating no value was loaded for an attribute when it was modified, +and flags indicated we were not to load it. +""" +) -""") +NEVER_SET = util.symbol('NEVER_SET', +"""Symbol which may be placed as the 'previous' value of an attribute +indicating that the attribute had not been assigned to previously. 
+""" +) + +NO_CHANGE = util.symbol("NO_CHANGE", +"""No callables or SQL should be emitted on attribute access +and no state should change""", canonical=0 +) + +CALLABLES_OK = util.symbol("CALLABLES_OK", +"""Loader callables can be fired off if a value +is not present.""", canonical=1 +) + +SQL_OK = util.symbol("SQL_OK", +"""Loader callables can emit SQL at least on scalar value +attributes.""", canonical=2) + +RELATED_OBJECT_OK = util.symbol("RELATED_OBJECT_OK", +"""callables can use SQL to load related objects as well +as scalar value attributes. +""", canonical=4 +) + +INIT_OK = util.symbol("INIT_OK", +"""Attributes should be initialized with a blank +value (None or an empty collection) upon get, if no other +value can be obtained. +""", canonical=8 +) + +NON_PERSISTENT_OK = util.symbol("NON_PERSISTENT_OK", +"""callables can be emitted if the parent is not persistent.""", +canonical=16 +) + + +# pre-packaged sets of flags used as inputs +PASSIVE_OFF = RELATED_OBJECT_OK | \ + NON_PERSISTENT_OK | \ + INIT_OK | \ + CALLABLES_OK | \ + SQL_OK + +PASSIVE_RETURN_NEVER_SET = PASSIVE_OFF ^ INIT_OK +PASSIVE_NO_INITIALIZE = PASSIVE_RETURN_NEVER_SET ^ CALLABLES_OK +PASSIVE_NO_FETCH = PASSIVE_OFF ^ SQL_OK +PASSIVE_NO_FETCH_RELATED = PASSIVE_OFF ^ RELATED_OBJECT_OK +PASSIVE_ONLY_PERSISTENT = PASSIVE_OFF ^ NON_PERSISTENT_OK class QueryableAttribute(interfaces.PropComparator): @@ -147,6 +168,10 @@ class QueryableAttribute(interfaces.PropComparator): return self.comparator.property +@inspection._inspects(QueryableAttribute) +def _get_prop(source): + return source.property + class InstrumentedAttribute(QueryableAttribute): """Class bound instrumented attribute which adds descriptor methods.""" @@ -184,11 +209,14 @@ def create_proxied_attribute(descriptor): """ - def __init__(self, class_, key, descriptor, comparator, - adapter=None, doc=None): + def __init__(self, class_, key, descriptor, + comparator, + adapter=None, doc=None, + original_property=None): self.class_ = class_ self.key = key self.descriptor = descriptor + self.original_property = original_property self._comparator = comparator self.adapter = adapter self.__doc__ = doc @@ -443,7 +471,7 @@ class AttributeImpl(object): key = self.key if key not in state.committed_state or \ state.committed_state[key] is NEVER_SET: - if passive is PASSIVE_NO_INITIALIZE: + if not passive & CALLABLES_OK: return PASSIVE_NO_RESULT if key in state.callables: @@ -468,7 +496,7 @@ class AttributeImpl(object): elif value is not ATTR_EMPTY: return self.set_committed_value(state, dict_, value) - if passive is PASSIVE_RETURN_NEVER_SET: + if not passive & INIT_OK: return NEVER_SET else: # Return a new, empty value @@ -525,7 +553,7 @@ class ScalarAttributeImpl(AttributeImpl): if self.dispatch.remove: self.fire_remove_event(state, dict_, old, None) - state.modified_event(dict_, self, old) + state._modified_event(dict_, self, old) del dict_[self.key] def get_history(self, state, dict_, passive=PASSIVE_OFF): @@ -545,7 +573,7 @@ class ScalarAttributeImpl(AttributeImpl): if self.dispatch.set: value = self.fire_replace_event(state, dict_, value, old, initiator) - state.modified_event(dict_, self, old) + state._modified_event(dict_, self, old) dict_[self.key] = value def fire_replace_event(self, state, dict_, value, previous, initiator): @@ -562,60 +590,6 @@ class ScalarAttributeImpl(AttributeImpl): self.property.columns[0].type -class MutableScalarAttributeImpl(ScalarAttributeImpl): - """represents a scalar value-holding InstrumentedAttribute, which can - detect changes within the 
value itself. - - """ - - uses_objects = False - supports_population = True - - def __init__(self, class_, key, callable_, dispatch, - class_manager, copy_function=None, - compare_function=None, **kwargs): - super(ScalarAttributeImpl, self).__init__( - class_, - key, - callable_, dispatch, - compare_function=compare_function, - **kwargs) - class_manager.mutable_attributes.add(key) - if copy_function is None: - raise sa_exc.ArgumentError( - "MutableScalarAttributeImpl requires a copy function") - self.copy = copy_function - - def get_history(self, state, dict_, passive=PASSIVE_OFF): - if not dict_: - v = state.committed_state.get(self.key, NO_VALUE) - else: - v = dict_.get(self.key, NO_VALUE) - - return History.from_scalar_attribute(self, state, v) - - def check_mutable_modified(self, state, dict_): - a, u, d = self.get_history(state, dict_) - return bool(a or d) - - def get(self, state, dict_, passive=PASSIVE_OFF): - if self.key not in state.mutable_dict: - ret = ScalarAttributeImpl.get(self, state, dict_, passive=passive) - if ret is not PASSIVE_NO_RESULT: - state.mutable_dict[self.key] = ret - return ret - else: - return state.mutable_dict[self.key] - - def delete(self, state, dict_): - ScalarAttributeImpl.delete(self, state, dict_) - state.mutable_dict.pop(self.key) - - def set(self, state, dict_, value, initiator, - passive=PASSIVE_OFF, check_old=None, pop=False): - ScalarAttributeImpl.set(self, state, dict_, value, - initiator, passive, check_old=check_old, pop=pop) - state.mutable_dict[self.key] = value class ScalarObjectAttributeImpl(ScalarAttributeImpl): @@ -639,8 +613,8 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl): if self.key in dict_: return History.from_object_attribute(self, state, dict_[self.key]) else: - if passive is PASSIVE_OFF: - passive = PASSIVE_RETURN_NEVER_SET + if passive & INIT_OK: + passive ^= INIT_OK current = self.get(state, dict_, passive=passive) if current is PASSIVE_NO_RESULT: return HISTORY_BLANK @@ -704,7 +678,7 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl): for fn in self.dispatch.remove: fn(state, value, initiator or self) - state.modified_event(dict_, self, value) + state._modified_event(dict_, self, value) def fire_replace_event(self, state, dict_, value, previous, initiator): if self.trackparent: @@ -716,7 +690,7 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl): for fn in self.dispatch.set: value = fn(state, value, previous, initiator or self) - state.modified_event(dict_, self, previous) + state._modified_event(dict_, self, previous) if self.trackparent: if value is not None: @@ -799,7 +773,7 @@ class CollectionAttributeImpl(AttributeImpl): for fn in self.dispatch.append: value = fn(state, value, initiator or self) - state.modified_event(dict_, self, NEVER_SET, True) + state._modified_event(dict_, self, NEVER_SET, True) if self.trackparent and value is not None: self.sethasparent(instance_state(value), state, True) @@ -807,7 +781,7 @@ class CollectionAttributeImpl(AttributeImpl): return value def fire_pre_remove_event(self, state, dict_, initiator): - state.modified_event(dict_, self, NEVER_SET, True) + state._modified_event(dict_, self, NEVER_SET, True) def fire_remove_event(self, state, dict_, value, initiator): if self.trackparent and value is not None: @@ -816,13 +790,13 @@ class CollectionAttributeImpl(AttributeImpl): for fn in self.dispatch.remove: fn(state, value, initiator or self) - state.modified_event(dict_, self, NEVER_SET, True) + state._modified_event(dict_, self, NEVER_SET, True) def delete(self, state, 
dict_): if self.key not in dict_: return - state.modified_event(dict_, self, NEVER_SET, True) + state._modified_event(dict_, self, NEVER_SET, True) collection = self.get_collection(state, state.dict) collection.clear_with_event() @@ -849,7 +823,7 @@ class CollectionAttributeImpl(AttributeImpl): value = self.fire_append_event(state, dict_, value, initiator) assert self.key not in dict_, \ "Collection was loaded during event handling." - state.get_pending(self.key).append(value) + state._get_pending_mutation(self.key).append(value) else: collection.append_with_event(value, initiator) @@ -862,7 +836,7 @@ class CollectionAttributeImpl(AttributeImpl): self.fire_remove_event(state, dict_, value, initiator) assert self.key not in dict_, \ "Collection was loaded during event handling." - state.get_pending(self.key).remove(value) + state._get_pending_mutation(self.key).remove(value) else: collection.remove_with_event(value, initiator) @@ -918,7 +892,7 @@ class CollectionAttributeImpl(AttributeImpl): return # place a copy of "old" in state.committed_state - state.modified_event(dict_, self, old, True) + state._modified_event(dict_, self, old, True) old_collection = getattr(old, '_sa_adapter') @@ -939,12 +913,12 @@ class CollectionAttributeImpl(AttributeImpl): state.commit(dict_, [self.key]) - if self.key in state.pending: + if self.key in state._pending_mutations: # pending items exist. issue a modified event, # add/remove new items. - state.modified_event(dict_, self, user_data, True) + state._modified_event(dict_, self, user_data, True) - pending = state.pending.pop(self.key) + pending = state._pending_mutations.pop(self.key) added = pending.added_items removed = pending.deleted_items for item in added: @@ -1246,7 +1220,7 @@ def register_attribute(class_, key, **kw): def register_attribute_impl(class_, key, uselist=False, callable_=None, - useobject=False, mutable_scalars=False, + useobject=False, impl_class=None, backref=None, **kw): manager = manager_of_class(class_) @@ -1267,9 +1241,6 @@ def register_attribute_impl(class_, key, elif useobject: impl = ScalarObjectAttributeImpl(class_, key, callable_, dispatch,**kw) - elif mutable_scalars: - impl = MutableScalarAttributeImpl(class_, key, callable_, dispatch, - class_manager=manager, **kw) else: impl = ScalarAttributeImpl(class_, key, callable_, dispatch, **kw) @@ -1391,5 +1362,5 @@ def flag_modified(instance, key): """ state, dict_ = instance_state(instance), instance_dict(instance) impl = state.manager[key].impl - state.modified_event(dict_, impl, NO_VALUE) + state._modified_event(dict_, impl, NO_VALUE) diff --git a/lib/sqlalchemy/orm/collections.py b/lib/sqlalchemy/orm/collections.py index 160fac8be..d51d7bcd2 100644 --- a/lib/sqlalchemy/orm/collections.py +++ b/lib/sqlalchemy/orm/collections.py @@ -111,22 +111,57 @@ import weakref from sqlalchemy.sql import expression from sqlalchemy import schema, util, exc as sa_exc - - __all__ = ['collection', 'collection_adapter', 'mapped_collection', 'column_mapped_collection', 'attribute_mapped_collection'] __instrumentation_mutex = util.threading.Lock() + +class _PlainColumnGetter(object): + """Plain column getter, stores collection of Column objects + directly. + + Serializes to a :class:`._SerializableColumnGetterV2` + which has more expensive __call__() performance + and some rare caveats. 
+ + """ + def __init__(self, cols): + self.cols = cols + self.composite = len(cols) > 1 + + def __reduce__(self): + return _SerializableColumnGetterV2._reduce_from_cols(self.cols) + + def _cols(self, mapper): + return self.cols + + def __call__(self, value): + state = instance_state(value) + m = _state_mapper(state) + + key = [ + m._get_state_attr_by_column(state, state.dict, col) + for col in self._cols(m) + ] + + if self.composite: + return tuple(key) + else: + return key[0] + class _SerializableColumnGetter(object): + """Column-based getter used in version 0.7.6 only. + + Remains here for pickle compatibility with 0.7.6. + + """ def __init__(self, colkeys): self.colkeys = colkeys self.composite = len(colkeys) > 1 - def __reduce__(self): return _SerializableColumnGetter, (self.colkeys,) - def __call__(self, value): state = instance_state(value) m = _state_mapper(state) @@ -139,6 +174,48 @@ class _SerializableColumnGetter(object): else: return key[0] +class _SerializableColumnGetterV2(_PlainColumnGetter): + """Updated serializable getter which deals with + multi-table mapped classes. + + Two extremely unusual cases are not supported. + Mappings which have tables across multiple metadata + objects, or which are mapped to non-Table selectables + linked across inheriting mappers may fail to function + here. + + """ + + def __init__(self, colkeys): + self.colkeys = colkeys + self.composite = len(colkeys) > 1 + + def __reduce__(self): + return self.__class__, (self.colkeys,) + + @classmethod + def _reduce_from_cols(cls, cols): + def _table_key(c): + if not isinstance(c.table, expression.TableClause): + return None + else: + return c.table.key + colkeys = [(c.key, _table_key(c)) for c in cols] + return _SerializableColumnGetterV2, (colkeys,) + + def _cols(self, mapper): + cols = [] + metadata = getattr(mapper.local_table, 'metadata', None) + for (ckey, tkey) in self.colkeys: + if tkey is None or \ + metadata is None or \ + tkey not in metadata: + cols.append(mapper.local_table.c[ckey]) + else: + cols.append(metadata.tables[tkey].c[ckey]) + return cols + + def column_mapped_collection(mapping_spec): """A dictionary-based collection type with column-based keying. 
@@ -155,10 +232,10 @@ def column_mapped_collection(mapping_spec): from sqlalchemy.orm.util import _state_mapper from sqlalchemy.orm.attributes import instance_state - cols = [c.key for c in [ - expression._only_column_elements(q, "mapping_spec") - for q in util.to_list(mapping_spec)]] - keyfunc = _SerializableColumnGetter(cols) + cols = [expression._only_column_elements(q, "mapping_spec") + for q in util.to_list(mapping_spec) + ] + keyfunc = _PlainColumnGetter(cols) return lambda: MappedCollection(keyfunc) class _SerializableAttrGetter(object): diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py index ed0d4924e..ba1109dfb 100644 --- a/lib/sqlalchemy/orm/descriptor_props.py +++ b/lib/sqlalchemy/orm/descriptor_props.py @@ -65,7 +65,8 @@ class DescriptorProperty(MapperProperty): self.key, self.descriptor, lambda: self._comparator_factory(mapper), - doc=self.doc + doc=self.doc, + original_property=self ) proxy_attr.impl = _ProxyImpl(self.key) mapper.class_manager.instrument_attribute(self.key, proxy_attr) diff --git a/lib/sqlalchemy/orm/dynamic.py b/lib/sqlalchemy/orm/dynamic.py index edf052870..18fc76aa9 100644 --- a/lib/sqlalchemy/orm/dynamic.py +++ b/lib/sqlalchemy/orm/dynamic.py @@ -57,7 +57,7 @@ class DynamicAttributeImpl(attributes.AttributeImpl): self.query_class = mixin_user_query(query_class) def get(self, state, dict_, passive=attributes.PASSIVE_OFF): - if passive is not attributes.PASSIVE_OFF: + if not passive & attributes.SQL_OK: return self._get_collection_history(state, attributes.PASSIVE_NO_INITIALIZE).added_items else: @@ -65,7 +65,7 @@ class DynamicAttributeImpl(attributes.AttributeImpl): def get_collection(self, state, dict_, user_data=None, passive=attributes.PASSIVE_NO_INITIALIZE): - if passive is not attributes.PASSIVE_OFF: + if not passive & attributes.SQL_OK: return self._get_collection_history(state, passive).added_items else: @@ -97,7 +97,7 @@ class DynamicAttributeImpl(attributes.AttributeImpl): if self.key not in state.committed_state: state.committed_state[self.key] = CollectionHistory(self, state) - state.modified_event(dict_, + state._modified_event(dict_, self, attributes.NEVER_SET) @@ -142,7 +142,7 @@ class DynamicAttributeImpl(attributes.AttributeImpl): c.deleted_items) def get_all_pending(self, state, dict_): - c = self._get_collection_history(state, True) + c = self._get_collection_history(state, attributes.PASSIVE_NO_INITIALIZE) return [ (attributes.instance_state(x), x) for x in @@ -155,7 +155,9 @@ class DynamicAttributeImpl(attributes.AttributeImpl): else: c = CollectionHistory(self, state) - if passive is attributes.PASSIVE_OFF: + # TODO: consider using a different flag here, possibly + # one local to dynamic + if passive & attributes.INIT_OK: return CollectionHistory(self, state, apply_to=c) else: return c diff --git a/lib/sqlalchemy/orm/identity.py b/lib/sqlalchemy/orm/identity.py index 59d121de9..bb5fbb6e8 100644 --- a/lib/sqlalchemy/orm/identity.py +++ b/lib/sqlalchemy/orm/identity.py @@ -10,7 +10,6 @@ from sqlalchemy.orm import attributes class IdentityMap(dict): def __init__(self): - self._mutable_attrs = set() self._modified = set() self._wr = weakref.ref(self) @@ -31,28 +30,18 @@ class IdentityMap(dict): if state.modified: self._modified.add(state) - if state.manager.mutable_attributes: - self._mutable_attrs.add(state) def _manage_removed_state(self, state): del state._instance_dict - self._mutable_attrs.discard(state) self._modified.discard(state) def _dirty_states(self): - return self._modified.union(s 
for s in self._mutable_attrs.copy() - if s.modified) + return self._modified def check_modified(self): """return True if any InstanceStates present have been marked as 'modified'.""" - if self._modified: - return True - else: - for state in self._mutable_attrs.copy(): - if state.modified: - return True - return False + return bool(self._modified) def has_key(self, key): return key in self diff --git a/lib/sqlalchemy/orm/instrumentation.py b/lib/sqlalchemy/orm/instrumentation.py index af9ef7841..e9d1ca36a 100644 --- a/lib/sqlalchemy/orm/instrumentation.py +++ b/lib/sqlalchemy/orm/instrumentation.py @@ -23,7 +23,7 @@ An example of full customization is in /examples/custom_attributes. from sqlalchemy.orm import exc, collections, events from operator import attrgetter, itemgetter -from sqlalchemy import event, util +from sqlalchemy import event, util, inspection import weakref from sqlalchemy.orm import state, attributes @@ -86,7 +86,6 @@ class ClassManager(dict): self.factory = None # where we came from, for inheritance bookkeeping self.info = {} self.new_init = None - self.mutable_attributes = set() self.local_attrs = {} self.originals = {} @@ -155,10 +154,7 @@ class ClassManager(dict): @util.memoized_property def _state_constructor(self): self.dispatch.first_init(self, self.class_) - if self.mutable_attributes: - return state.MutableAttrInstanceState - else: - return state.InstanceState + return state.InstanceState def manage(self): """Mark this instance as the manager for its class.""" @@ -209,8 +205,6 @@ class ClassManager(dict): del self.local_attrs[key] self.uninstall_descriptor(key) del self[key] - if key in self.mutable_attributes: - self.mutable_attributes.remove(key) for cls in self.class_.__subclasses__(): manager = manager_of_class(cls) if manager: diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index e96b7549a..795447763 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -33,6 +33,7 @@ from sqlalchemy.orm.util import _INSTRUMENTOR, _class_to_mapper, \ import sys sessionlib = util.importlater("sqlalchemy.orm", "session") properties = util.importlater("sqlalchemy.orm", "properties") +descriptor_props = util.importlater("sqlalchemy.orm", "descriptor_props") __all__ = ( 'Mapper', @@ -678,9 +679,10 @@ class Mapper(object): self._reconstructor = method event.listen(manager, 'load', _event_on_load, raw=True) elif hasattr(method, '__sa_validators__'): + include_removes = getattr(method, "__sa_include_removes__", False) for name in method.__sa_validators__: self.validators = self.validators.union( - {name : method} + {name : (method, include_removes)} ) manager.info[_INSTRUMENTOR] = self @@ -1356,7 +1358,8 @@ class Mapper(object): spec = self.with_polymorphic[0] if selectable is False: selectable = self.with_polymorphic[1] - + elif selectable is False: + selectable = None mappers = self._mappers_from_spec(spec, selectable) if selectable is not None: return mappers, selectable @@ -1392,12 +1395,35 @@ class Mapper(object): continue yield c - @property - def properties(self): - raise NotImplementedError( - "Public collection of MapperProperty objects is " - "provided by the get_property() and iterate_properties " - "accessors.") + @util.memoized_property + def attr(self): + if _new_mappers: + configure_mappers() + return util.ImmutableProperties(self._props) + + @_memoized_configured_property + def synonyms(self): + return self._filter_properties(descriptor_props.SynonymProperty) + + @_memoized_configured_property + def 
column_attrs(self): + return self._filter_properties(properties.ColumnProperty) + + @_memoized_configured_property + def relationships(self): + return self._filter_properties(properties.RelationshipProperty) + + @_memoized_configured_property + def composites(self): + return self._filter_properties(descriptor_props.CompositeProperty) + + def _filter_properties(self, type_): + if _new_mappers: + configure_mappers() + return util.ImmutableProperties(dict( + (k, v) for k, v in self._props.iteritems() + if isinstance(v, type_) + )) @_memoized_configured_property def _get_clause(self): @@ -2291,7 +2317,7 @@ def reconstructor(fn): fn.__sa_reconstructor__ = True return fn -def validates(*names): +def validates(*names, **kw): """Decorate a method as a 'validator' for one or more named properties. Designates a method as a validator, a method which receives the @@ -2307,9 +2333,16 @@ def validates(*names): an assertion to avoid recursion overflows. This is a reentrant condition which is not supported. + :param \*names: list of attribute names to be validated. + :param include_removes: if True, "remove" events will be + sent as well - the validation function must accept an additional + argument "is_remove" which will be a boolean. New in 0.7.7. + """ + include_removes = kw.pop('include_removes', False) def wrap(fn): fn.__sa_validators__ = names + fn.__sa_include_removes__ = include_removes return fn return wrap diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 74ccf0157..79676642c 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -14,11 +14,11 @@ mapped attributes. from sqlalchemy import sql, util, log, exc as sa_exc from sqlalchemy.sql.util import ClauseAdapter, criterion_as_pairs, \ join_condition, _shallow_annotate -from sqlalchemy.sql import operators, expression +from sqlalchemy.sql import operators, expression, visitors from sqlalchemy.orm import attributes, dependency, mapper, \ - object_mapper, strategies, configure_mappers + object_mapper, strategies, configure_mappers, relationships from sqlalchemy.orm.util import CascadeOptions, _class_to_mapper, \ - _orm_annotate, _orm_deannotate + _orm_annotate, _orm_deannotate, _orm_full_deannotate from sqlalchemy.orm.interfaces import MANYTOMANY, MANYTOONE, \ MapperProperty, ONETOMANY, PropComparator, StrategizedProperty @@ -33,9 +33,9 @@ from descriptor_props import CompositeProperty, SynonymProperty, \ class ColumnProperty(StrategizedProperty): """Describes an object attribute that corresponds to a table column. - + Public constructor is the :func:`.orm.column_property` function. - + """ def __init__(self, *columns, **kwargs): @@ -62,7 +62,7 @@ class ColumnProperty(StrategizedProperty): """ self._orig_columns = [expression._labeled(c) for c in columns] - self.columns = [expression._labeled(_orm_deannotate(c)) + self.columns = [expression._labeled(_orm_full_deannotate(c)) for c in columns] self.group = kwargs.pop('group', None) self.deferred = kwargs.pop('deferred', False) @@ -99,6 +99,13 @@ class ColumnProperty(StrategizedProperty): else: self.strategy_class = strategies.ColumnLoader + @property + def expression(self): + """Return the primary column or expression for this ColumnProperty. 
+ + """ + return self.columns[0] + def instrument_class(self, mapper): if not self.instrument: return @@ -177,13 +184,13 @@ log.class_logger(ColumnProperty) class RelationshipProperty(StrategizedProperty): """Describes an object property that holds a single item or list of items that correspond to a related database table. - + Public constructor is the :func:`.orm.relationship` function. - + Of note here is the :class:`.RelationshipProperty.Comparator` class, which implements comparison operations for scalar- and collection-referencing mapped attributes. - + """ strategy_wildcard_key = 'relationship:*' @@ -276,7 +283,6 @@ class RelationshipProperty(StrategizedProperty): else: self.backref = backref - def instrument_class(self, mapper): attributes.register_descriptor( mapper.class_, @@ -293,7 +299,7 @@ class RelationshipProperty(StrategizedProperty): def __init__(self, prop, mapper, of_type=None, adapter=None): """Construction of :class:`.RelationshipProperty.Comparator` is internal to the ORM's attribute mechanics. - + """ self.prop = prop self.mapper = mapper @@ -332,10 +338,10 @@ class RelationshipProperty(StrategizedProperty): def of_type(self, cls): """Produce a construct that represents a particular 'subtype' of attribute for the parent class. - + Currently this is usable in conjunction with :meth:`.Query.join` and :meth:`.Query.outerjoin`. - + """ return RelationshipProperty.Comparator( self.property, @@ -345,7 +351,7 @@ class RelationshipProperty(StrategizedProperty): def in_(self, other): """Produce an IN clause - this is not implemented for :func:`~.orm.relationship`-based attributes at this time. - + """ raise NotImplementedError('in_() not yet supported for ' 'relationships. For a simple many-to-one, use ' @@ -362,15 +368,15 @@ class RelationshipProperty(StrategizedProperty): this will typically produce a clause such as:: - + mytable.related_id == <some id> - + Where ``<some id>`` is the primary key of the given object. - + The ``==`` operator provides partial functionality for non- many-to-one comparisons: - + * Comparisons against collections are not supported. Use :meth:`~.RelationshipProperty.Comparator.contains`. * Compared to a scalar one-to-many, will produce a @@ -445,6 +451,7 @@ class RelationshipProperty(StrategizedProperty): else: j = _orm_annotate(pj, exclude=self.property.remote_side) + # MARKMARK if criterion is not None and target_adapter: # limit this adapter to annotated only? criterion = target_adapter.traverse(criterion) @@ -465,42 +472,42 @@ class RelationshipProperty(StrategizedProperty): def any(self, criterion=None, **kwargs): """Produce an expression that tests a collection against particular criterion, using EXISTS. - + An expression like:: - + session.query(MyClass).filter( MyClass.somereference.any(SomeRelated.x==2) ) - - + + Will produce a query like:: - + SELECT * FROM my_table WHERE EXISTS (SELECT 1 FROM related WHERE related.my_id=my_table.id AND related.x=2) - + Because :meth:`~.RelationshipProperty.Comparator.any` uses a correlated subquery, its performance is not nearly as good when compared against large target tables as that of using a join. - + :meth:`~.RelationshipProperty.Comparator.any` is particularly useful for testing for empty collections:: - + session.query(MyClass).filter( ~MyClass.somereference.any() ) - + will produce:: - + SELECT * FROM my_table WHERE NOT EXISTS (SELECT 1 FROM related WHERE related.my_id=my_table.id) - + :meth:`~.RelationshipProperty.Comparator.any` is only valid for collections, i.e. 
a :func:`.relationship` that has ``uselist=True``. For scalar references, use :meth:`~.RelationshipProperty.Comparator.has`. - + """ if not self.property.uselist: raise sa_exc.InvalidRequestError( @@ -515,14 +522,14 @@ class RelationshipProperty(StrategizedProperty): particular criterion, using EXISTS. An expression like:: - + session.query(MyClass).filter( MyClass.somereference.has(SomeRelated.x==2) ) - - + + Will produce a query like:: - + SELECT * FROM my_table WHERE EXISTS (SELECT 1 FROM related WHERE related.id==my_table.related_id AND related.x=2) @@ -531,12 +538,12 @@ class RelationshipProperty(StrategizedProperty): a correlated subquery, its performance is not nearly as good when compared against large target tables as that of using a join. - + :meth:`~.RelationshipProperty.Comparator.has` is only valid for scalar references, i.e. a :func:`.relationship` that has ``uselist=False``. For collection references, use :meth:`~.RelationshipProperty.Comparator.any`. - + """ if self.property.uselist: raise sa_exc.InvalidRequestError( @@ -547,44 +554,44 @@ class RelationshipProperty(StrategizedProperty): def contains(self, other, **kwargs): """Return a simple expression that tests a collection for containment of a particular item. - + :meth:`~.RelationshipProperty.Comparator.contains` is only valid for a collection, i.e. a :func:`~.orm.relationship` that implements one-to-many or many-to-many with ``uselist=True``. - + When used in a simple one-to-many context, an expression like:: - + MyClass.contains(other) - + Produces a clause like:: - + mytable.id == <some id> - + Where ``<some id>`` is the value of the foreign key attribute on ``other`` which refers to the primary key of its parent object. From this it follows that :meth:`~.RelationshipProperty.Comparator.contains` is very useful when used with simple one-to-many operations. - + For many-to-many operations, the behavior of :meth:`~.RelationshipProperty.Comparator.contains` has more caveats. The association table will be rendered in the statement, producing an "implicit" join, that is, includes multiple tables in the FROM clause which are equated in the WHERE clause:: - + query(MyClass).filter(MyClass.contains(other)) - + Produces a query like:: - + SELECT * FROM my_table, my_association_table AS my_association_table_1 WHERE my_table.id = my_association_table_1.parent_id AND my_association_table_1.child_id = <some id> - + Where ``<some id>`` would be the primary key of ``other``. From the above, it is clear that :meth:`~.RelationshipProperty.Comparator.contains` @@ -598,7 +605,7 @@ class RelationshipProperty(StrategizedProperty): a less-performant alternative using EXISTS, or refer to :meth:`.Query.outerjoin` as well as :ref:`ormtutorial_joins` for more details on constructing outer joins. - + """ if not self.property.uselist: raise sa_exc.InvalidRequestError( @@ -649,19 +656,19 @@ class RelationshipProperty(StrategizedProperty): """Implement the ``!=`` operator. In a many-to-one context, such as:: - + MyClass.some_prop != <some object> - + This will typically produce a clause such as:: - + mytable.related_id != <some id> - + Where ``<some id>`` is the primary key of the given object. - + The ``!=`` operator provides partial functionality for non- many-to-one comparisons: - + * Comparisons against collections are not supported. Use :meth:`~.RelationshipProperty.Comparator.contains` @@ -682,7 +689,7 @@ class RelationshipProperty(StrategizedProperty): membership tests. 
* Comparisons against ``None`` given in a one-to-many or many-to-many context produce an EXISTS clause. - + """ if isinstance(other, (NoneType, expression._Null)): if self.property.direction == MANYTOONE: @@ -804,6 +811,27 @@ class RelationshipProperty(StrategizedProperty): dest_state.get_impl(self.key).set(dest_state, dest_dict, obj, None) + def _value_as_iterable(self, state, dict_, key, + passive=attributes.PASSIVE_OFF): + """Return a list of tuples (state, obj) for the given + key. + + returns an empty list if the value is None/empty/PASSIVE_NO_RESULT + """ + + impl = state.manager[key].impl + x = impl.get(state, dict_, passive=passive) + if x is attributes.PASSIVE_NO_RESULT or x is None: + return [] + elif hasattr(impl, 'get_collection'): + return [ + (attributes.instance_state(o), o) for o in + impl.get_collection(state, dict_, x, passive=passive) + ] + else: + return [(attributes.instance_state(x), x)] + + def cascade_iterator(self, type_, state, dict_, visited_states, halt_on=None): #assert type_ in self.cascade @@ -818,7 +846,7 @@ class RelationshipProperty(StrategizedProperty): get_all_pending(state, dict_) else: - tuples = state.value_as_iterable(dict_, self.key, + tuples = self._value_as_iterable(state, dict_, self.key, passive=passive) skip_pending = type_ == 'refresh-expire' and 'delete-orphan' \ @@ -880,9 +908,9 @@ class RelationshipProperty(StrategizedProperty): def mapper(self): """Return the targeted :class:`.Mapper` for this :class:`.RelationshipProperty`. - + This is a lazy-initializing static attribute. - + """ if isinstance(self.argument, type): mapper_ = mapper.class_mapper(self.argument, @@ -914,58 +942,25 @@ class RelationshipProperty(StrategizedProperty): def do_init(self): self._check_conflicts() self._process_dependent_arguments() - self._determine_joins() - self._determine_synchronize_pairs() - self._determine_direction() - self._determine_local_remote_pairs() + self._setup_join_conditions() + self._check_cascade_settings() self._post_init() self._generate_backref() super(RelationshipProperty, self).do_init() - def _check_conflicts(self): - """Test that this relationship is legal, warn about - inheritance conflicts.""" - - if not self.is_primary() \ - and not mapper.class_mapper( - self.parent.class_, - compile=False).has_property(self.key): - raise sa_exc.ArgumentError("Attempting to assign a new " - "relationship '%s' to a non-primary mapper on " - "class '%s'. New relationships can only be added " - "to the primary mapper, i.e. the very first mapper " - "created for class '%s' " % (self.key, - self.parent.class_.__name__, - self.parent.class_.__name__)) - - # check for conflicting relationship() on superclass - if not self.parent.concrete: - for inheriting in self.parent.iterate_to_root(): - if inheriting is not self.parent \ - and inheriting.has_property(self.key): - util.warn("Warning: relationship '%s' on mapper " - "'%s' supersedes the same relationship " - "on inherited mapper '%s'; this can " - "cause dependency issues during flush" - % (self.key, self.parent, inheriting)) - def _process_dependent_arguments(self): """Convert incoming configuration arguments to their proper form. - + Callables are resolved, ORM annotations removed. - + """ # accept callables for other attributes which may require # deferred initialization. This technique is used # by declarative "string configs" and some recipes. 
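The callable-resolution step described in the comment above is what lets :func:`.relationship` arguments be supplied as lambdas that are evaluated only at mapper configuration time, once all classes exist. A minimal sketch of that usage; the ``Parent``/``Child`` models are hypothetical and not part of this changeset::

    from sqlalchemy import Column, ForeignKey, Integer
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import relationship

    Base = declarative_base()

    class Parent(Base):
        __tablename__ = 'parent'
        id = Column(Integer, primary_key=True)
        # primaryjoin is given as a callable; it is invoked during
        # configuration, by which time Child has been defined
        children = relationship("Child",
                primaryjoin=lambda: Parent.id == Child.parent_id)

    class Child(Base):
        __tablename__ = 'child'
        id = Column(Integer, primary_key=True)
        parent_id = Column(Integer, ForeignKey('parent.id'))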
for attr in ( - 'order_by', - 'primaryjoin', - 'secondaryjoin', - 'secondary', - '_user_defined_foreign_keys', - 'remote_side', + 'order_by', 'primaryjoin', 'secondaryjoin', + 'secondary', '_user_defined_foreign_keys', 'remote_side', ): attr_value = getattr(self, attr) if util.callable(attr_value): @@ -1008,288 +1003,69 @@ class RelationshipProperty(StrategizedProperty): (self.key, self.parent.class_) ) - def _determine_joins(self): - """Determine the 'primaryjoin' and 'secondaryjoin' attributes, - if not passed to the constructor already. - - This is based on analysis of the foreign key relationships - between the parent and target mapped selectables. - - """ - if self.secondaryjoin is not None and self.secondary is None: - raise sa_exc.ArgumentError("Property '" + self.key - + "' specified with secondary join condition but " - "no secondary argument") - - # if join conditions were not specified, figure them out based - # on foreign keys - - def _search_for_join(mapper, table): - # find a join between the given mapper's mapped table and - # the given table. will try the mapper's local table first - # for more specificity, then if not found will try the more - # general mapped table, which in the case of inheritance is - # a join. - return join_condition(mapper.mapped_table, table, - a_subset=mapper.local_table) - - try: - if self.secondary is not None: - if self.secondaryjoin is None: - self.secondaryjoin = _search_for_join(self.mapper, - self.secondary) - if self.primaryjoin is None: - self.primaryjoin = _search_for_join(self.parent, - self.secondary) - else: - if self.primaryjoin is None: - self.primaryjoin = _search_for_join(self.parent, - self.target) - except sa_exc.ArgumentError, e: - raise sa_exc.ArgumentError("Could not determine join " - "condition between parent/child tables on " - "relationship %s. Specify a 'primaryjoin' " - "expression. If 'secondary' is present, " - "'secondaryjoin' is needed as well." - % self) + def _setup_join_conditions(self): + self._join_condition = jc = relationships.JoinCondition( + parent_selectable=self.parent.mapped_table, + child_selectable=self.mapper.mapped_table, + parent_local_selectable=self.parent.local_table, + child_local_selectable=self.mapper.local_table, + primaryjoin=self.primaryjoin, + secondary=self.secondary, + secondaryjoin=self.secondaryjoin, + parent_equivalents=self.parent._equivalent_columns, + child_equivalents=self.mapper._equivalent_columns, + consider_as_foreign_keys=self._user_defined_foreign_keys, + local_remote_pairs=self.local_remote_pairs, + remote_side=self.remote_side, + self_referential=self._is_self_referential, + prop=self, + support_sync=not self.viewonly, + can_be_synced_fn=self._columns_are_mapped + ) + self.primaryjoin = jc.deannotated_primaryjoin + self.secondaryjoin = jc.deannotated_secondaryjoin + self.direction = jc.direction + self.local_remote_pairs = jc.local_remote_pairs + self.remote_side = jc.remote_columns + self.local_columns = jc.local_columns + self.synchronize_pairs = jc.synchronize_pairs + self._calculated_foreign_keys = jc.foreign_key_columns + self.secondary_synchronize_pairs = jc.secondary_synchronize_pairs - def _columns_are_mapped(self, *cols): - """Return True if all columns in the given collection are - mapped by the tables referenced by this :class:`.Relationship`. 
- - """ - for c in cols: - if self.secondary is not None \ - and self.secondary.c.contains_column(c): - continue - if not self.parent.mapped_table.c.contains_column(c) and \ - not self.target.c.contains_column(c): - return False - return True - - def _sync_pairs_from_join(self, join_condition, primary): - """Determine a list of "source"/"destination" column pairs - based on the given join condition, as well as the - foreign keys argument. - - "source" would be a column referenced by a foreign key, - and "destination" would be the column who has a foreign key - reference to "source". - - """ - - fks = self._user_defined_foreign_keys - # locate pairs - eq_pairs = criterion_as_pairs(join_condition, - consider_as_foreign_keys=fks, - any_operator=self.viewonly) - - # couldn't find any fks, but we have - # "secondary" - assume the "secondary" columns - # are the fks - if not eq_pairs and \ - self.secondary is not None and \ - not fks: - fks = set(self.secondary.c) - eq_pairs = criterion_as_pairs(join_condition, - consider_as_foreign_keys=fks, - any_operator=self.viewonly) - - if eq_pairs: - util.warn("No ForeignKey objects were present " - "in secondary table '%s'. Assumed referenced " - "foreign key columns %s for join condition '%s' " - "on relationship %s" % ( - self.secondary.description, - ", ".join(sorted(["'%s'" % col for col in fks])), - join_condition, - self - )) - - # Filter out just to columns that are mapped. - # If viewonly, allow pairs where the FK col - # was part of "foreign keys" - the column it references - # may be in an un-mapped table - see - # test.orm.test_relationships.ViewOnlyComplexJoin.test_basic - # for an example of this. - eq_pairs = [(l, r) for (l, r) in eq_pairs - if self._columns_are_mapped(l, r) - or self.viewonly and - r in fks] - - if eq_pairs: - return eq_pairs - - # from here below is just determining the best error message - # to report. Check for a join condition using any operator - # (not just ==), perhaps they need to turn on "viewonly=True". - if not self.viewonly and criterion_as_pairs(join_condition, - consider_as_foreign_keys=self._user_defined_foreign_keys, - any_operator=True): - - err = "Could not locate any "\ - "foreign-key-equated, locally mapped column "\ - "pairs for %s "\ - "condition '%s' on relationship %s." % ( - primary and 'primaryjoin' or 'secondaryjoin', - join_condition, - self - ) - - if not self._user_defined_foreign_keys: - err += " Ensure that the "\ - "referencing Column objects have a "\ - "ForeignKey present, or are otherwise part "\ - "of a ForeignKeyConstraint on their parent "\ - "Table, or specify the foreign_keys parameter "\ - "to this relationship." - - err += " For more "\ - "relaxed rules on join conditions, the "\ - "relationship may be marked as viewonly=True." - - raise sa_exc.ArgumentError(err) - else: - if self._user_defined_foreign_keys: - raise sa_exc.ArgumentError("Could not determine " - "relationship direction for %s condition " - "'%s', on relationship %s, using manual " - "'foreign_keys' setting. Do the columns " - "in 'foreign_keys' represent all, and " - "only, the 'foreign' columns in this join " - "condition? Does the %s Table already " - "have adequate ForeignKey and/or " - "ForeignKeyConstraint objects established " - "(in which case 'foreign_keys' is usually " - "unnecessary)?" 
- % ( - primary and 'primaryjoin' or 'secondaryjoin', - join_condition, - self, - primary and 'mapped' or 'secondary' - )) - else: - raise sa_exc.ArgumentError("Could not determine " - "relationship direction for %s condition " - "'%s', on relationship %s. Ensure that the " - "referencing Column objects have a " - "ForeignKey present, or are otherwise part " - "of a ForeignKeyConstraint on their parent " - "Table, or specify the foreign_keys parameter " - "to this relationship." - % ( - primary and 'primaryjoin' or 'secondaryjoin', - join_condition, - self - )) - - def _determine_synchronize_pairs(self): - """Resolve 'primary'/foreign' column pairs from the primaryjoin - and secondaryjoin arguments. - - """ - if self.local_remote_pairs: - if not self._user_defined_foreign_keys: - raise sa_exc.ArgumentError( - "foreign_keys argument is " - "required with _local_remote_pairs argument") - self.synchronize_pairs = [] - for l, r in self.local_remote_pairs: - if r in self._user_defined_foreign_keys: - self.synchronize_pairs.append((l, r)) - elif l in self._user_defined_foreign_keys: - self.synchronize_pairs.append((r, l)) - else: - self.synchronize_pairs = self._sync_pairs_from_join( - self.primaryjoin, - True) - - self._calculated_foreign_keys = util.column_set( - r for (l, r) in - self.synchronize_pairs) - - if self.secondaryjoin is not None: - self.secondary_synchronize_pairs = self._sync_pairs_from_join( - self.secondaryjoin, - False) - self._calculated_foreign_keys.update( - r for (l, r) in - self.secondary_synchronize_pairs) - else: - self.secondary_synchronize_pairs = None - - def _determine_direction(self): - """Determine if this relationship is one to many, many to one, - many to many. - - This is derived from the primaryjoin, presence of "secondary", - and in the case of self-referential the "remote side". - - """ - if self.secondaryjoin is not None: - self.direction = MANYTOMANY - elif self._refers_to_parent_table(): - - # self referential defaults to ONETOMANY unless the "remote" - # side is present and does not reference any foreign key - # columns - - if self.local_remote_pairs: - remote = [r for (l, r) in self.local_remote_pairs] - elif self.remote_side: - remote = self.remote_side - else: - remote = None - if not remote or self._calculated_foreign_keys.difference(l for (l, - r) in self.synchronize_pairs).intersection(remote): - self.direction = ONETOMANY - else: - self.direction = MANYTOONE - else: - parentcols = util.column_set(self.parent.mapped_table.c) - targetcols = util.column_set(self.mapper.mapped_table.c) - - # fk collection which suggests ONETOMANY. - onetomany_fk = targetcols.intersection( - self._calculated_foreign_keys) - - # fk collection which suggests MANYTOONE. + def _check_conflicts(self): + """Test that this relationship is legal, warn about + inheritance conflicts.""" - manytoone_fk = parentcols.intersection( - self._calculated_foreign_keys) + if not self.is_primary() \ + and not mapper.class_mapper( + self.parent.class_, + compile=False).has_property(self.key): + raise sa_exc.ArgumentError("Attempting to assign a new " + "relationship '%s' to a non-primary mapper on " + "class '%s'. New relationships can only be added " + "to the primary mapper, i.e. the very first mapper " + "created for class '%s' " % (self.key, + self.parent.class_.__name__, + self.parent.class_.__name__)) - if onetomany_fk and manytoone_fk: - # fks on both sides. do the same test only based on the - # local side. 
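The error path being removed here covered join conditions whose foreign key columns appear in both the parent and the child table, directing users to the ``foreign_keys`` argument. A hedged sketch of that resolution, reusing a declarative ``Base`` as in the earlier example (mutually-dependent tables also need ``use_alter`` on one constraint so CREATE ordering works)::

    class A(Base):
        __tablename__ = 'a'
        id = Column(Integer, primary_key=True)
        b_id = Column(Integer,
                ForeignKey('b.id', use_alter=True, name='fk_a_b'))

    class B(Base):
        __tablename__ = 'b'
        id = Column(Integer, primary_key=True)
        a_id = Column(Integer, ForeignKey('a.id'))
        # foreign key columns exist in both tables; foreign_keys
        # names a_id as the foreign side, so the direction of this
        # property resolves to many-to-one
        a = relationship("A", foreign_keys=[a_id])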
- referents = [c for (c, f) in self.synchronize_pairs] - onetomany_local = parentcols.intersection(referents) - manytoone_local = targetcols.intersection(referents) + # check for conflicting relationship() on superclass + if not self.parent.concrete: + for inheriting in self.parent.iterate_to_root(): + if inheriting is not self.parent \ + and inheriting.has_property(self.key): + util.warn("Warning: relationship '%s' on mapper " + "'%s' supersedes the same relationship " + "on inherited mapper '%s'; this can " + "cause dependency issues during flush" + % (self.key, self.parent, inheriting)) - if onetomany_local and not manytoone_local: - self.direction = ONETOMANY - elif manytoone_local and not onetomany_local: - self.direction = MANYTOONE - else: - raise sa_exc.ArgumentError( - "Can't determine relationship" - " direction for relationship '%s' - foreign " - "key columns are present in both the parent " - "and the child's mapped tables. Specify " - "'foreign_keys' argument." % self) - elif onetomany_fk: - self.direction = ONETOMANY - elif manytoone_fk: - self.direction = MANYTOONE - else: - raise sa_exc.ArgumentError("Can't determine relationship " - "direction for relationship '%s' - foreign " - "key columns are present in neither the parent " - "nor the child's mapped tables" % self) + def _check_cascade_settings(self): if self.cascade.delete_orphan and not self.single_parent \ and (self.direction is MANYTOMANY or self.direction is MANYTOONE): - util.warn('On %s, delete-orphan cascade is not supported ' + raise sa_exc.ArgumentError( + 'On %s, delete-orphan cascade is not supported ' 'on a many-to-many or many-to-one relationship ' 'when single_parent is not set. Set ' 'single_parent=True on the relationship().' @@ -1300,84 +1076,24 @@ class RelationshipProperty(StrategizedProperty): "relationships only." % self) - def _determine_local_remote_pairs(self): - """Determine pairs of columns representing "local" to - "remote", where "local" columns are on the parent mapper, - "remote" are on the target mapper. - - These pairs are used on the load side only to generate - lazy loading clauses. + def _columns_are_mapped(self, *cols): + """Return True if all columns in the given collection are + mapped by the tables referenced by this :class:`.Relationship`. """ - if not self.local_remote_pairs and not self.remote_side: - # the most common, trivial case. Derive - # local/remote pairs from the synchronize pairs. - eq_pairs = util.unique_list( - self.synchronize_pairs + - (self.secondary_synchronize_pairs or [])) - if self.direction is MANYTOONE: - self.local_remote_pairs = [(r, l) for l, r in eq_pairs] - else: - self.local_remote_pairs = eq_pairs - - # "remote_side" specified, derive from the primaryjoin - # plus remote_side, similarly to how synchronize_pairs - # were determined. 
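The ``remote_side`` handling deleted just below is the path taken by self-referential mappings. The canonical adjacency-list form, with a hypothetical ``Node`` model on the same declarative ``Base``::

    class Node(Base):
        __tablename__ = 'node'
        id = Column(Integer, primary_key=True)
        parent_id = Column(Integer, ForeignKey('node.id'))
        # remote_side marks node.id as the remote column, so the
        # lazy-load clause compares against it and this property
        # becomes the many-to-one side of the self-join
        parent = relationship("Node", remote_side=[id])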
- elif self.remote_side: - if self.local_remote_pairs: - raise sa_exc.ArgumentError('remote_side argument is ' - 'redundant against more detailed ' - '_local_remote_side argument.') - if self.direction is MANYTOONE: - self.local_remote_pairs = [(r, l) for (l, r) in - criterion_as_pairs(self.primaryjoin, - consider_as_referenced_keys=self.remote_side, - any_operator=True)] - - else: - self.local_remote_pairs = \ - criterion_as_pairs(self.primaryjoin, - consider_as_foreign_keys=self.remote_side, - any_operator=True) - if not self.local_remote_pairs: - raise sa_exc.ArgumentError('Relationship %s could ' - 'not determine any local/remote column ' - 'pairs from remote side argument %r' - % (self, self.remote_side)) - # else local_remote_pairs were sent explcitly via - # ._local_remote_pairs. - - # create local_side/remote_side accessors - self.local_side = util.ordered_column_set( - l for l, r in self.local_remote_pairs) - self.remote_side = util.ordered_column_set( - r for l, r in self.local_remote_pairs) - - # check that the non-foreign key column in the local/remote - # collection is mapped. The foreign key - # which the individual mapped column references directly may - # itself be in a non-mapped table; see - # test.orm.test_relationships.ViewOnlyComplexJoin.test_basic - # for an example of this. - if self.direction is ONETOMANY: - for col in self.local_side: - if not self._columns_are_mapped(col): - raise sa_exc.ArgumentError( - "Local column '%s' is not " - "part of mapping %s. Specify remote_side " - "argument to indicate which column lazy join " - "condition should compare against." % (col, - self.parent)) - elif self.direction is MANYTOONE: - for col in self.remote_side: - if not self._columns_are_mapped(col): - raise sa_exc.ArgumentError( - "Remote column '%s' is not " - "part of mapping %s. Specify remote_side " - "argument to indicate which column lazy join " - "condition should bind." % (col, self.mapper)) + for c in cols: + if self.secondary is not None \ + and self.secondary.c.contains_column(c): + continue + if not self.parent.mapped_table.c.contains_column(c) and \ + not self.target.c.contains_column(c): + return False + return True def _generate_backref(self): + """Interpret the 'backref' instruction to create a + :func:`.relationship` complementary to this one.""" + if not self.is_primary(): return if self.backref is not None and not self.back_populates: @@ -1391,17 +1107,27 @@ class RelationshipProperty(StrategizedProperty): "'%s' on relationship '%s': property of that " "name exists on mapper '%s'" % (backref_key, self, mapper)) + + # determine primaryjoin/secondaryjoin for the + # backref. Use the one we had, so that + # a custom join doesn't have to be specified in + # both directions. if self.secondary is not None: - pj = kwargs.pop('primaryjoin', self.secondaryjoin) - sj = kwargs.pop('secondaryjoin', self.primaryjoin) + # for many to many, just switch primaryjoin/ + # secondaryjoin. use the annotated + # pj/sj on the _join_condition. + pj = kwargs.pop('primaryjoin', self._join_condition.secondaryjoin) + sj = kwargs.pop('secondaryjoin', self._join_condition.primaryjoin) else: - pj = kwargs.pop('primaryjoin', self.primaryjoin) + pj = kwargs.pop('primaryjoin', + self._join_condition.primaryjoin_reverse_remote) sj = kwargs.pop('secondaryjoin', None) if sj: raise sa_exc.InvalidRequestError( - "Can't assign 'secondaryjoin' on a backref against " - "a non-secondary relationship." - ) + "Can't assign 'secondaryjoin' on a backref " + "against a non-secondary relationship." 
+ ) + foreign_keys = kwargs.pop('foreign_keys', self._user_defined_foreign_keys) parent = self.parent.primary_mapper() @@ -1410,35 +1136,17 @@ class RelationshipProperty(StrategizedProperty): kwargs.setdefault('passive_updates', self.passive_updates) self.back_populates = backref_key relationship = RelationshipProperty( - parent, - self.secondary, - pj, - sj, + parent, self.secondary, + pj, sj, foreign_keys=foreign_keys, back_populates=self.key, - **kwargs - ) + **kwargs) mapper._configure_property(backref_key, relationship) if self.back_populates: self._add_reverse_property(self.back_populates) def _post_init(self): - self.logger.info('%s setup primary join %s', self, - self.primaryjoin) - self.logger.info('%s setup secondary join %s', self, - self.secondaryjoin) - self.logger.info('%s synchronize pairs [%s]', self, - ','.join('(%s => %s)' % (l, r) for (l, r) in - self.synchronize_pairs)) - self.logger.info('%s secondary synchronize pairs [%s]', self, - ','.join('(%s => %s)' % (l, r) for (l, r) in - self.secondary_synchronize_pairs or [])) - self.logger.info('%s local/remote pairs [%s]', self, - ','.join('(%s / %s)' % (l, r) for (l, r) in - self.local_remote_pairs)) - self.logger.info('%s relationship direction %s', self, - self.direction) if self.uselist is None: self.uselist = self.direction is not MANYTOONE if not self.viewonly: @@ -1453,20 +1161,6 @@ class RelationshipProperty(StrategizedProperty): strategy = self._get_strategy(strategies.LazyLoader) return strategy.use_get - def _refers_to_parent_table(self): - pt = self.parent.mapped_table - mt = self.mapper.mapped_table - for c, f in self.synchronize_pairs: - if ( - pt.is_derived_from(c.table) and \ - pt.is_derived_from(f.table) and \ - mt.is_derived_from(c.table) and \ - mt.is_derived_from(f.table) - ): - return True - else: - return False - @util.memoized_property def _is_self_referential(self): return self.mapper.common_parent(self.parent) @@ -1496,75 +1190,22 @@ class RelationshipProperty(StrategizedProperty): else: aliased = True - # place a barrier on the destination such that - # replacement traversals won't ever dig into it. - # its internal structure remains fixed - # regardless of context. - dest_selectable = _shallow_annotate( - dest_selectable, - {'no_replacement_traverse':True}) - - aliased = aliased or (source_selectable is not None) - - primaryjoin, secondaryjoin, secondary = self.primaryjoin, \ - self.secondaryjoin, self.secondary - - # adjust the join condition for single table inheritance, - # in the case that the join is to a subclass - # this is analogous to the "_adjust_for_single_table_inheritance()" - # method in Query. 
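The single-inheritance adjustment described in the comment above comes into play when a relationship targets a single-table subclass. Roughly, under hypothetical models::

    from sqlalchemy import Column, ForeignKey, Integer, String
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import relationship

    Base = declarative_base()

    class Company(Base):
        __tablename__ = 'company'
        id = Column(Integer, primary_key=True)
        # targets a single-table subclass; the discriminator
        # criterion (employee.type restricted to 'manager') is
        # appended to the join condition by the logic shown here
        managers = relationship("Manager")

    class Employee(Base):
        __tablename__ = 'employee'
        id = Column(Integer, primary_key=True)
        company_id = Column(Integer, ForeignKey('company.id'))
        type = Column(String(20))
        __mapper_args__ = {'polymorphic_on': type,
                           'polymorphic_identity': 'employee'}

    class Manager(Employee):
        # no table of its own - single table inheritance
        __mapper_args__ = {'polymorphic_identity': 'manager'}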
- dest_mapper = of_type or self.mapper single_crit = dest_mapper._single_table_criterion - if single_crit is not None: - if secondaryjoin is not None: - secondaryjoin = secondaryjoin & single_crit - else: - primaryjoin = primaryjoin & single_crit - - if aliased: - if secondary is not None: - secondary = secondary.alias() - primary_aliasizer = ClauseAdapter(secondary) - secondary_aliasizer = \ - ClauseAdapter(dest_selectable, - equivalents=self.mapper._equivalent_columns).\ - chain(primary_aliasizer) - if source_selectable is not None: - primary_aliasizer = \ - ClauseAdapter(secondary).\ - chain(ClauseAdapter(source_selectable, - equivalents=self.parent._equivalent_columns)) - secondaryjoin = \ - secondary_aliasizer.traverse(secondaryjoin) - else: - primary_aliasizer = ClauseAdapter(dest_selectable, - exclude=self.local_side, - equivalents=self.mapper._equivalent_columns) - if source_selectable is not None: - primary_aliasizer.chain( - ClauseAdapter(source_selectable, - exclude=self.remote_side, - equivalents=self.parent._equivalent_columns)) - secondary_aliasizer = None - primaryjoin = primary_aliasizer.traverse(primaryjoin) - target_adapter = secondary_aliasizer or primary_aliasizer - target_adapter.include = target_adapter.exclude = None - else: - target_adapter = None + aliased = aliased or (source_selectable is not None) + + primaryjoin, secondaryjoin, secondary, target_adapter, dest_selectable = \ + self._join_condition.join_targets( + source_selectable, dest_selectable, aliased, single_crit + ) if source_selectable is None: source_selectable = self.parent.local_table if dest_selectable is None: dest_selectable = self.mapper.local_table - return ( - primaryjoin, - secondaryjoin, - source_selectable, - dest_selectable, - secondary, - target_adapter, - ) + return (primaryjoin, secondaryjoin, source_selectable, + dest_selectable, secondary, target_adapter) + PropertyLoader = RelationProperty = RelationshipProperty log.class_logger(RelationshipProperty) diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 66d7f6eb4..dda231e0c 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -31,6 +31,7 @@ from sqlalchemy.orm import ( ) from sqlalchemy.orm.util import ( AliasedClass, ORMAdapter, _entity_descriptor, _entity_info, + _extended_entity_info, _is_aliased_class, _is_mapped_class, _orm_columns, _orm_selectable, join as orm_join,with_parent, _attr_as_key, aliased ) @@ -92,6 +93,7 @@ class Query(object): _from_obj = () _join_entities = () _select_from_entity = None + _mapper_adapter_map = {} _filter_aliases = None _from_obj_alias = None _joinpath = _joinpoint = util.immutabledict() @@ -114,50 +116,43 @@ class Query(object): for ent in util.to_list(entities): entity_wrapper(self, ent) - self._setup_aliasizers(self._entities) + self._set_entity_selectables(self._entities) - def _setup_aliasizers(self, entities): - if hasattr(self, '_mapper_adapter_map'): - # usually safe to share a single map, but copying to prevent - # subtle leaks if end-user is reusing base query with arbitrary - # number of aliased() objects - self._mapper_adapter_map = d = self._mapper_adapter_map.copy() - else: - self._mapper_adapter_map = d = {} + def _set_entity_selectables(self, entities): + self._mapper_adapter_map = d = self._mapper_adapter_map.copy() for ent in entities: for entity in ent.entities: if entity not in d: - mapper, selectable, is_aliased_class = \ - _entity_info(entity) + mapper, selectable, \ + is_aliased_class, with_polymorphic_mappers, \ + 
with_polymorphic_discriminator = \ + _extended_entity_info(entity) if not is_aliased_class and mapper.with_polymorphic: - with_polymorphic = mapper._with_polymorphic_mappers if mapper.mapped_table not in \ self._polymorphic_adapters: self._mapper_loads_polymorphically_with(mapper, sql_util.ColumnAdapter( selectable, mapper._equivalent_columns)) - adapter = None + aliased_adapter = None elif is_aliased_class: - adapter = sql_util.ColumnAdapter( + aliased_adapter = sql_util.ColumnAdapter( selectable, mapper._equivalent_columns) - with_polymorphic = None else: - with_polymorphic = adapter = None + aliased_adapter = None - d[entity] = (mapper, adapter, selectable, - is_aliased_class, with_polymorphic) + d[entity] = (mapper, aliased_adapter, selectable, + is_aliased_class, with_polymorphic_mappers, + with_polymorphic_discriminator) ent.setup_entity(entity, *d[entity]) def _mapper_loads_polymorphically_with(self, mapper, adapter): for m2 in mapper._with_polymorphic_mappers: self._polymorphic_adapters[m2] = adapter for m in m2.iterate_to_root(): - self._polymorphic_adapters[m.mapped_table] = \ - self._polymorphic_adapters[m.local_table] = \ - adapter + self._polymorphic_adapters[m.local_table] = adapter def _set_select_from(self, *obj): @@ -180,10 +175,9 @@ class Query(object): for m2 in mapper._with_polymorphic_mappers: self._polymorphic_adapters.pop(m2, None) for m in m2.iterate_to_root(): - self._polymorphic_adapters.pop(m.mapped_table, None) self._polymorphic_adapters.pop(m.local_table, None) - def __adapt_polymorphic_element(self, element): + def _adapt_polymorphic_element(self, element): if isinstance(element, expression.FromClause): search = element elif hasattr(element, 'table'): @@ -241,7 +235,7 @@ class Query(object): if self._polymorphic_adapters: adapters.append( ( - orm_only, self.__adapt_polymorphic_element + orm_only, self._adapt_polymorphic_element ) ) @@ -617,35 +611,29 @@ class Query(object): @_generative(_no_clauseelement_condition) def with_polymorphic(self, cls_or_mappers, - selectable=None, discriminator=None): - """Load columns for descendant mappers of this Query's mapper. - - Using this method will ensure that each descendant mapper's - tables are included in the FROM clause, and will allow filter() - criterion to be used against those tables. The resulting - instances will also have those columns already loaded so that - no "post fetch" of those columns will be required. - - :param cls_or_mappers: a single class or mapper, or list of - class/mappers, which inherit from this Query's mapper. - Alternatively, it may also be the string ``'*'``, in which case - all descending mappers will be added to the FROM clause. - - :param selectable: a table or select() statement that will - be used in place of the generated FROM clause. This argument is - required if any of the desired mappers use concrete table - inheritance, since SQLAlchemy currently cannot generate UNIONs - among tables automatically. If used, the ``selectable`` argument - must represent the full set of tables and columns mapped by every - desired mapper. Otherwise, the unaccounted mapped columns will - result in their table being appended directly to the FROM clause - which will usually lead to incorrect results. - - :param discriminator: a column to be used as the "discriminator" - column for the given selectable. If not given, the polymorphic_on - attribute of the mapper will be used, if any. This is useful for - mappers that don't have polymorphic loading behavior by default, - such as concrete table mappers. 
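Before the rewritten signature and docstring that follow, a usage sketch of the method itself; it assumes a hypothetical ``Employee`` hierarchy with ``Engineer`` and ``Manager`` subclasses carrying ``engineer_info``/``manager_data`` columns, and an active ``session``::

    from sqlalchemy import or_

    # pull the Engineer and Manager columns into the FROM clause
    # up front, so filter() can reference subclass columns directly
    # and no per-row "post fetch" is needed
    q = session.query(Employee).with_polymorphic([Engineer, Manager])
    q = q.filter(or_(Engineer.engineer_info == 'x',
                     Manager.manager_data == 'y'))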
+ selectable=None, + polymorphic_on=None): + """Load columns for inheriting classes. + + :meth:`.Query.with_polymorphic` applies transformations + to the "main" mapped class represented by this :class:`.Query`. + The "main" mapped class here means the :class:`.Query` + object's first argument is a full class, i.e. ``session.query(SomeClass)``. + These transformations allow additional tables to be present + in the FROM clause so that columns for a joined-inheritance + subclass are available in the query, both for the purposes + of load-time efficiency as well as the ability to use + these columns at query time. + + See the documentation section :ref:`with_polymorphic` for + details on how this method is used. + + As of 0.8, a new and more flexible function + :func:`.orm.with_polymorphic` supersedes + :meth:`.Query.with_polymorphic`, as it can apply the equivalent + functionality to any set of columns or classes in the + :class:`.Query`, not just the "zero mapper". See that + function for a description of arguments. """ @@ -657,7 +645,7 @@ class Query(object): entity.set_with_polymorphic(self, cls_or_mappers, selectable=selectable, - discriminator=discriminator) + polymorphic_on=polymorphic_on) @_generative() def yield_per(self, count): @@ -764,7 +752,7 @@ class Query(object): not mapper.always_refresh and \ self._lockmode is None: - instance = self._get_from_identity(self.session, key, False) + instance = self._get_from_identity(self.session, key, attributes.PASSIVE_OFF) if instance is not None: # reject calls for id in identity map but class # mismatch. @@ -881,7 +869,7 @@ class Query(object): self._entities = list(self._entities) m = _MapperEntity(self, entity) - self._setup_aliasizers([m]) + self._set_entity_selectables([m]) @_generative() def with_session(self, session): @@ -998,7 +986,7 @@ class Query(object): _ColumnEntity(self, c) # _ColumnEntity may add many entities if the # given arg is a FROM clause - self._setup_aliasizers(self._entities[l:]) + self._set_entity_selectables(self._entities[l:]) @util.pending_deprecation("0.7", ":meth:`.add_column` is superseded by :meth:`.add_columns`", @@ -2426,10 +2414,10 @@ class Query(object): # expired - ensure it still exists if state.expired: - if passive is attributes.PASSIVE_NO_FETCH: + if not passive & attributes.SQL_OK: # TODO: no coverage here return attributes.PASSIVE_NO_RESULT - elif passive is attributes.PASSIVE_NO_FETCH_RELATED: + elif not passive & attributes.RELATED_OBJECT_OK: # this mode is used within a flush and the instance's # expired state will be checked soon enough, if necessary return instance @@ -2907,7 +2895,6 @@ class Query(object): context.whereclause, from_obj=froms, use_labels=labels, - correlate=False, # TODO: this order_by is only needed if # LIMIT/OFFSET is present in self._select_args, # else the application on the outside is enough @@ -2973,7 +2960,6 @@ class Query(object): from_obj=froms, use_labels=labels, for_update=for_update, - correlate=False, order_by=context.order_by, **self._select_args ) @@ -3000,7 +2986,7 @@ class Query(object): selected from the total results. 
""" - for entity, (mapper, adapter, s, i, w) in \ + for entity, (mapper, adapter, s, i, w, d) in \ self._mapper_adapter_map.iteritems(): if entity in self._join_entities: continue @@ -3044,14 +3030,16 @@ class _MapperEntity(_QueryEntity): self.entities = [entity] self.entity_zero = self.expr = entity - def setup_entity(self, entity, mapper, adapter, - from_obj, is_aliased_class, with_polymorphic): + def setup_entity(self, entity, mapper, aliased_adapter, + from_obj, is_aliased_class, + with_polymorphic, + with_polymorphic_discriminator): self.mapper = mapper - self.adapter = adapter + self.aliased_adapter = aliased_adapter self.selectable = from_obj - self._with_polymorphic = with_polymorphic - self._polymorphic_discriminator = None self.is_aliased_class = is_aliased_class + self._with_polymorphic = with_polymorphic + self._polymorphic_discriminator = with_polymorphic_discriminator if is_aliased_class: self.path_entity = self.entity_zero = entity self._path = (entity,) @@ -3064,9 +3052,14 @@ class _MapperEntity(_QueryEntity): self.entity_zero = mapper self._label_name = self.mapper.class_.__name__ - def set_with_polymorphic(self, query, cls_or_mappers, - selectable, discriminator): + selectable, polymorphic_on): + if self.is_aliased_class: + raise NotImplementedError( + "Can't use with_polymorphic() against " + "an Aliased object" + ) + if cls_or_mappers is None: query._reset_polymorphic_adapter(self.mapper) return @@ -3074,15 +3067,12 @@ class _MapperEntity(_QueryEntity): mappers, from_obj = self.mapper._with_polymorphic_args( cls_or_mappers, selectable) self._with_polymorphic = mappers - self._polymorphic_discriminator = discriminator + self._polymorphic_discriminator = polymorphic_on - # TODO: do the wrapped thing here too so that - # with_polymorphic() can be applied to aliases - if not self.is_aliased_class: - self.selectable = from_obj - query._mapper_loads_polymorphically_with(self.mapper, - sql_util.ColumnAdapter(from_obj, - self.mapper._equivalent_columns)) + self.selectable = from_obj + query._mapper_loads_polymorphically_with(self.mapper, + sql_util.ColumnAdapter(from_obj, + self.mapper._equivalent_columns)) filter_fn = id @@ -3106,11 +3096,12 @@ class _MapperEntity(_QueryEntity): def _get_entity_clauses(self, query, context): adapter = None - if not self.is_aliased_class and query._polymorphic_adapters: - adapter = query._polymorphic_adapters.get(self.mapper, None) - if not adapter and self.adapter: - adapter = self.adapter + if not self.is_aliased_class: + if query._polymorphic_adapters: + adapter = query._polymorphic_adapters.get(self.mapper, None) + else: + adapter = self.aliased_adapter if adapter: if query._from_obj_alias: @@ -3196,7 +3187,10 @@ class _MapperEntity(_QueryEntity): column_collection=context.primary_columns ) - if self._polymorphic_discriminator is not None: + if self._polymorphic_discriminator is not None and \ + self._polymorphic_discriminator \ + is not self.mapper.polymorphic_on: + if adapter: pd = adapter.columns[self._polymorphic_discriminator] else: @@ -3246,8 +3240,9 @@ class _ColumnEntity(_QueryEntity): # can be located in the result even # if the expression's identity has been changed # due to adaption. 
- if not column._label: - column = column.label(None) + + if not column._label and not getattr(column, 'is_literal', False): + column = column.label(self._label_name) query._entities.append(self) @@ -3299,7 +3294,8 @@ class _ColumnEntity(_QueryEntity): c.entities = self.entities def setup_entity(self, entity, mapper, adapter, from_obj, - is_aliased_class, with_polymorphic): + is_aliased_class, with_polymorphic, + with_polymorphic_discriminator): if 'selectable' not in self.__dict__: self.selectable = from_obj self.froms.add(from_obj) diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py new file mode 100644 index 000000000..4c64e855f --- /dev/null +++ b/lib/sqlalchemy/orm/relationships.py @@ -0,0 +1,856 @@ +# orm/relationships.py +# Copyright (C) 2005-2012 the SQLAlchemy authors and contributors <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Heuristics related to join conditions as used in +:func:`.relationship`. + +Provides the :class:`.JoinCondition` object, which encapsulates +SQL annotation and aliasing behavior focused on the `primaryjoin` +and `secondaryjoin` aspects of :func:`.relationship`. + +""" + +from sqlalchemy import sql, util, log, exc as sa_exc, schema +from sqlalchemy.sql.util import ClauseAdapter, criterion_as_pairs, \ + join_condition, _shallow_annotate, visit_binary_product,\ + _deep_deannotate +from sqlalchemy.sql import operators, expression, visitors +from sqlalchemy.orm.interfaces import MANYTOMANY, MANYTOONE, ONETOMANY + +def remote(expr): + return _annotate_columns(expr, {"remote":True}) + +def foreign(expr): + return _annotate_columns(expr, {"foreign":True}) + +def remote_foreign(expr): + return _annotate_columns(expr, {"foreign":True, + "remote":True}) + +def _annotate_columns(element, annotations): + def clone(elem): + if isinstance(elem, expression.ColumnClause): + elem = elem._annotate(annotations.copy()) + elem._copy_internals(clone=clone) + return elem + + if element is not None: + element = clone(element) + return element + +class JoinCondition(object): + def __init__(self, + parent_selectable, + child_selectable, + parent_local_selectable, + child_local_selectable, + primaryjoin=None, + secondary=None, + secondaryjoin=None, + parent_equivalents=None, + child_equivalents=None, + consider_as_foreign_keys=None, + local_remote_pairs=None, + remote_side=None, + self_referential=False, + prop=None, + support_sync=True, + can_be_synced_fn=lambda *c: True + ): + self.parent_selectable = parent_selectable + self.parent_local_selectable = parent_local_selectable + self.child_selectable = child_selectable + self.child_local_selectable = child_local_selectable + self.parent_equivalents = parent_equivalents + self.child_equivalents = child_equivalents + self.primaryjoin = primaryjoin + self.secondaryjoin = secondaryjoin + self.secondary = secondary + self.consider_as_foreign_keys = consider_as_foreign_keys + self._local_remote_pairs = local_remote_pairs + self._remote_side = remote_side + self.prop = prop + self.self_referential = self_referential + self.support_sync = support_sync + self.can_be_synced_fn = can_be_synced_fn + self._determine_joins() + self._annotate_fks() + self._annotate_remote() + self._annotate_local() + self._setup_pairs() + self._check_foreign_cols(self.primaryjoin, True) + if self.secondaryjoin is not None: + self._check_foreign_cols(self.secondaryjoin, False) + self._determine_direction() + 
self._check_remote_side() + self._log_joins() + + def _log_joins(self): + if self.prop is None: + return + log = self.prop.logger + log.info('%s setup primary join %s', self, + self.primaryjoin) + log.info('%s setup secondary join %s', self, + self.secondaryjoin) + log.info('%s synchronize pairs [%s]', self, + ','.join('(%s => %s)' % (l, r) for (l, r) in + self.synchronize_pairs)) + log.info('%s secondary synchronize pairs [%s]', self, + ','.join('(%s => %s)' % (l, r) for (l, r) in + self.secondary_synchronize_pairs or [])) + log.info('%s local/remote pairs [%s]', self, + ','.join('(%s / %s)' % (l, r) for (l, r) in + self.local_remote_pairs)) + log.info('%s relationship direction %s', self, + self.direction) + + def _determine_joins(self): + """Determine the 'primaryjoin' and 'secondaryjoin' attributes, + if not passed to the constructor already. + + This is based on analysis of the foreign key relationships + between the parent and target mapped selectables. + + """ + if self.secondaryjoin is not None and self.secondary is None: + raise sa_exc.ArgumentError( + "Property %s specified with secondary " + "join condition but " + "no secondary argument" % self.prop) + + # find a join between the given mapper's mapped table and + # the given table. will try the mapper's local table first + # for more specificity, then if not found will try the more + # general mapped table, which in the case of inheritance is + # a join. + try: + if self.secondary is not None: + if self.secondaryjoin is None: + self.secondaryjoin = \ + join_condition( + self.child_selectable, + self.secondary, + a_subset=self.child_local_selectable, + consider_as_foreign_keys=\ + self.consider_as_foreign_keys or None + ) + if self.primaryjoin is None: + self.primaryjoin = \ + join_condition( + self.parent_selectable, + self.secondary, + a_subset=self.parent_local_selectable, + consider_as_foreign_keys=\ + self.consider_as_foreign_keys or None + ) + else: + if self.primaryjoin is None: + self.primaryjoin = \ + join_condition( + self.parent_selectable, + self.child_selectable, + a_subset=self.parent_local_selectable, + consider_as_foreign_keys=\ + self.consider_as_foreign_keys or None + ) + except sa_exc.NoForeignKeysError, nfke: + if self.secondary is not None: + raise sa_exc.NoForeignKeysError("Could not determine join " + "condition between parent/child tables on " + "relationship %s - there are no foreign keys " + "linking these tables via secondary table '%s'. " + "Ensure that referencing columns are associated " + "with a ForeignKey or ForeignKeyConstraint, or " + "specify 'primaryjoin' and 'secondaryjoin' " + "expressions." + % (self.prop, self.secondary)) + else: + raise sa_exc.NoForeignKeysError("Could not determine join " + "condition between parent/child tables on " + "relationship %s - there are no foreign keys " + "linking these tables. " + "Ensure that referencing columns are associated " + "with a ForeignKey or ForeignKeyConstraint, or " + "specify a 'primaryjoin' expression." + % self.prop) + except sa_exc.AmbiguousForeignKeysError, afke: + if self.secondary is not None: + raise sa_exc.AmbiguousForeignKeysError( + "Could not determine join " + "condition between parent/child tables on " + "relationship %s - there are multiple foreign key " + "paths linking the tables via secondary table '%s'. " + "Specify the 'foreign_keys' " + "argument, providing a list of those columns which " + "should be counted as containing a foreign key " + "reference from the secondary table to each of the " + "parent and child tables." 
+ % (self.prop, self.secondary)) + else: + raise sa_exc.AmbiguousForeignKeysError( + "Could not determine join " + "condition between parent/child tables on " + "relationship %s - there are multiple foreign key " + "paths linking the tables. Specify the " + "'foreign_keys' argument, providing a list of those " + "columns which should be counted as containing a " + "foreign key reference to the parent table." + % self.prop) + + @util.memoized_property + def primaryjoin_reverse_remote(self): + """Return the primaryjoin condition suitable for the + "reverse" direction. + + If the primaryjoin was delivered here with pre-existing + "remote" annotations, the local/remote annotations + are reversed. Otherwise, the local/remote annotations + are removed. + + """ + if self._has_remote_annotations: + def replace(element): + if "remote" in element._annotations: + v = element._annotations.copy() + del v['remote'] + v['local'] = True + return element._with_annotations(v) + elif "local" in element._annotations: + v = element._annotations.copy() + del v['local'] + v['remote'] = True + return element._with_annotations(v) + return visitors.replacement_traverse( + self.primaryjoin, {}, replace) + else: + if self._has_foreign_annotations: + # TODO: coverage + return _deep_deannotate(self.primaryjoin, + values=("local", "remote")) + else: + return _deep_deannotate(self.primaryjoin) + + def _has_annotation(self, clause, annotation): + for col in visitors.iterate(clause, {}): + if annotation in col._annotations: + return True + else: + return False + + @util.memoized_property + def _has_foreign_annotations(self): + return self._has_annotation(self.primaryjoin, "foreign") + + @util.memoized_property + def _has_remote_annotations(self): + return self._has_annotation(self.primaryjoin, "remote") + + def _annotate_fks(self): + """Annotate the primaryjoin and secondaryjoin + structures with 'foreign' annotations marking columns + considered as foreign. 
+ + """ + if self._has_foreign_annotations: + return + + if self.consider_as_foreign_keys: + self._annotate_from_fk_list() + else: + self._annotate_present_fks() + + def _annotate_from_fk_list(self): + def check_fk(col): + if col in self.consider_as_foreign_keys: + return col._annotate({"foreign":True}) + self.primaryjoin = visitors.replacement_traverse( + self.primaryjoin, + {}, + check_fk + ) + if self.secondaryjoin is not None: + self.secondaryjoin = visitors.replacement_traverse( + self.secondaryjoin, + {}, + check_fk + ) + + def _annotate_present_fks(self): + if self.secondary is not None: + secondarycols = util.column_set(self.secondary.c) + else: + secondarycols = set() + def is_foreign(a, b): + if isinstance(a, schema.Column) and \ + isinstance(b, schema.Column): + if a.references(b): + return a + elif b.references(a): + return b + + if secondarycols: + if a in secondarycols and b not in secondarycols: + return a + elif b in secondarycols and a not in secondarycols: + return b + + def visit_binary(binary): + if not isinstance(binary.left, sql.ColumnElement) or \ + not isinstance(binary.right, sql.ColumnElement): + return + + if "foreign" not in binary.left._annotations and \ + "foreign" not in binary.right._annotations: + col = is_foreign(binary.left, binary.right) + if col is not None: + if col.compare(binary.left): + binary.left = binary.left._annotate( + {"foreign":True}) + elif col.compare(binary.right): + binary.right = binary.right._annotate( + {"foreign":True}) + + self.primaryjoin = visitors.cloned_traverse( + self.primaryjoin, + {}, + {"binary":visit_binary} + ) + if self.secondaryjoin is not None: + self.secondaryjoin = visitors.cloned_traverse( + self.secondaryjoin, + {}, + {"binary":visit_binary} + ) + + def _refers_to_parent_table(self): + """Return True if the join condition contains column + comparisons where both columns are in both tables. + + """ + pt = self.parent_selectable + mt = self.child_selectable + result = [False] + def visit_binary(binary): + c, f = binary.left, binary.right + if ( + isinstance(c, expression.ColumnClause) and \ + isinstance(f, expression.ColumnClause) and \ + pt.is_derived_from(c.table) and \ + pt.is_derived_from(f.table) and \ + mt.is_derived_from(c.table) and \ + mt.is_derived_from(f.table) + ): + result[0] = True + visitors.traverse( + self.primaryjoin, + {}, + {"binary":visit_binary} + ) + return result[0] + + def _tables_overlap(self): + """Return True if parent/child tables have some overlap.""" + + return self.parent_selectable.is_derived_from( + self.child_local_selectable) or \ + self.child_selectable.is_derived_from( + self.parent_local_selectable) + + def _annotate_remote(self): + """Annotate the primaryjoin and secondaryjoin + structures with 'remote' annotations marking columns + considered as part of the 'remote' side. + + """ + if self._has_remote_annotations: + return + + parentcols = util.column_set(self.parent_selectable.c) + + if self.secondary is not None: + self._annotate_remote_secondary() + elif self._local_remote_pairs or self._remote_side: + self._annotate_remote_from_args() + elif self._refers_to_parent_table(): + self._annotate_selfref(lambda col:"foreign" in col._annotations) + elif self._tables_overlap(): + self._annotate_remote_with_overlap() + else: + self._annotate_remote_distinct_selectables() + + def _annotate_remote_secondary(self): + """annotate 'remote' in primaryjoin, secondaryjoin + when 'secondary' is present. 
+ + """ + def repl(element): + if self.secondary.c.contains_column(element): + return element._annotate({"remote":True}) + self.primaryjoin = visitors.replacement_traverse( + self.primaryjoin, {}, repl) + self.secondaryjoin = visitors.replacement_traverse( + self.secondaryjoin, {}, repl) + + def _annotate_selfref(self, fn): + """annotate 'remote' in primaryjoin, secondaryjoin + when the relationship is detected as self-referential. + + """ + def visit_binary(binary): + equated = binary.left.compare(binary.right) + if isinstance(binary.left, expression.ColumnClause) and \ + isinstance(binary.right, expression.ColumnClause): + # assume one to many - FKs are "remote" + if fn(binary.left): + binary.left = binary.left._annotate({"remote":True}) + if fn(binary.right) and \ + not equated: + binary.right = binary.right._annotate( + {"remote":True}) + else: + self._warn_non_column_elements() + + self.primaryjoin = visitors.cloned_traverse( + self.primaryjoin, {}, + {"binary":visit_binary}) + + def _annotate_remote_from_args(self): + """annotate 'remote' in primaryjoin, secondaryjoin + when the 'remote_side' or '_local_remote_pairs' + arguments are used. + + """ + if self._local_remote_pairs: + if self._remote_side: + raise sa_exc.ArgumentError( + "remote_side argument is redundant " + "against more detailed _local_remote_side " + "argument.") + + remote_side = [r for (l, r) in self._local_remote_pairs] + else: + remote_side = self._remote_side + + if self._refers_to_parent_table(): + self._annotate_selfref(lambda col:col in remote_side) + else: + def repl(element): + if element in remote_side: + return element._annotate({"remote":True}) + self.primaryjoin = visitors.replacement_traverse( + self.primaryjoin, {}, repl) + + def _annotate_remote_with_overlap(self): + """annotate 'remote' in primaryjoin, secondaryjoin + when the parent/child tables have some set of + tables in common, though is not a fully self-referential + relationship. + + """ + def visit_binary(binary): + binary.left, binary.right = proc_left_right(binary.left, + binary.right) + binary.right, binary.left = proc_left_right(binary.right, + binary.left) + def proc_left_right(left, right): + if isinstance(left, expression.ColumnClause) and \ + isinstance(right, expression.ColumnClause): + if self.child_selectable.c.contains_column(right) and \ + self.parent_selectable.c.contains_column(left): + right = right._annotate({"remote":True}) + else: + self._warn_non_column_elements() + + return left, right + + self.primaryjoin = visitors.cloned_traverse( + self.primaryjoin, {}, + {"binary":visit_binary}) + + def _annotate_remote_distinct_selectables(self): + """annotate 'remote' in primaryjoin, secondaryjoin + when the parent/child tables are entirely + separate. + + """ + def repl(element): + if self.child_selectable.c.contains_column(element) and \ + ( + not self.parent_local_selectable.c.\ + contains_column(element) + or self.child_local_selectable.c.\ + contains_column(element) + ): + return element._annotate({"remote":True}) + self.primaryjoin = visitors.replacement_traverse( + self.primaryjoin, {}, repl) + + def _warn_non_column_elements(self): + util.warn( + "Non-simple column elements in primary " + "join condition for property %s - consider using " + "remote() annotations to mark the remote side." + % self.prop + ) + + def _annotate_local(self): + """Annotate the primaryjoin and secondaryjoin + structures with 'local' annotations. 
+ + This annotates all column elements found + simultaneously in the parent table + and the join condition that don't have a + 'remote' annotation set up from + _annotate_remote() or user-defined. + + """ + if self._has_annotation(self.primaryjoin, "local"): + return + + parentcols = util.column_set(self.parent_selectable.c) + + if self._local_remote_pairs: + local_side = util.column_set([l for (l, r) + in self._local_remote_pairs]) + else: + local_side = util.column_set(self.parent_selectable.c) + + def locals_(elem): + if "remote" not in elem._annotations and \ + elem in local_side: + return elem._annotate({"local":True}) + self.primaryjoin = visitors.replacement_traverse( + self.primaryjoin, {}, locals_ + ) + + def _check_remote_side(self): + if not self.local_remote_pairs: + raise sa_exc.ArgumentError('Relationship %s could ' + 'not determine any unambiguous local/remote column ' + 'pairs based on join condition and remote_side ' + 'arguments. ' + 'Consider using the remote() annotation to ' + 'accurately mark those elements of the join ' + 'condition that are on the remote side of ' + 'the relationship.' + % (self.prop, )) + + def _check_foreign_cols(self, join_condition, primary): + """Check the foreign key columns collected and emit error + messages.""" + + can_sync = False + + foreign_cols = self._gather_columns_with_annotation( + join_condition, "foreign") + + has_foreign = bool(foreign_cols) + + if primary: + can_sync = bool(self.synchronize_pairs) + else: + can_sync = bool(self.secondary_synchronize_pairs) + + if self.support_sync and can_sync or \ + (not self.support_sync and has_foreign): + return + + # from here below is just determining the best error message + # to report. Check for a join condition using any operator + # (not just ==), perhaps they need to turn on "viewonly=True". + if self.support_sync and has_foreign and not can_sync: + err = "Could not locate any simple equality expressions "\ + "involving locally mapped foreign key columns for "\ + "%s join condition "\ + "'%s' on relationship %s." % ( + primary and 'primary' or 'secondary', + join_condition, + self.prop + ) + err += \ + " Ensure that referencing columns are associated "\ + "with a ForeignKey or ForeignKeyConstraint, or are "\ + "annotated in the join condition with the foreign() "\ + "annotation. To allow comparison operators other than "\ + "'==', the relationship can be marked as viewonly=True." + + raise sa_exc.ArgumentError(err) + else: + err = "Could not locate any relevant foreign key columns "\ + "for %s join condition '%s' on relationship %s." % ( + primary and 'primary' or 'secondary', + join_condition, + self.prop + ) + err += \ + ' Ensure that referencing columns are associated '\ + 'with a ForeignKey or ForeignKeyConstraint, or are '\ + 'annotated in the join condition with the foreign() '\ + 'annotation.' + raise sa_exc.ArgumentError(err) + + def _determine_direction(self): + """Determine if this relationship is one to many, many to one, + many to many. + + """ + if self.secondaryjoin is not None: + self.direction = MANYTOMANY + else: + parentcols = util.column_set(self.parent_selectable.c) + targetcols = util.column_set(self.child_selectable.c) + + # fk collection which suggests ONETOMANY. + onetomany_fk = targetcols.intersection( + self.foreign_key_columns) + + # fk collection which suggests MANYTOONE. + + manytoone_fk = parentcols.intersection( + self.foreign_key_columns) + + if onetomany_fk and manytoone_fk: + # fks on both sides. 
test for overlap of local/remote + # with foreign key + self_equated = self.remote_columns.intersection( + self.local_columns + ) + onetomany_local = self.remote_columns.\ + intersection(self.foreign_key_columns).\ + difference(self_equated) + manytoone_local = self.local_columns.\ + intersection(self.foreign_key_columns).\ + difference(self_equated) + if onetomany_local and not manytoone_local: + self.direction = ONETOMANY + elif manytoone_local and not onetomany_local: + self.direction = MANYTOONE + else: + raise sa_exc.ArgumentError( + "Can't determine relationship" + " direction for relationship '%s' - foreign " + "key columns within the join condition are present " + "in both the parent and the child's mapped tables. " + "Ensure that only those columns referring " + "to a parent column are marked as foreign, " + "either via the foreign() annotation or " + "via the foreign_keys argument." + % self.prop) + elif onetomany_fk: + self.direction = ONETOMANY + elif manytoone_fk: + self.direction = MANYTOONE + else: + raise sa_exc.ArgumentError("Can't determine relationship " + "direction for relationship '%s' - foreign " + "key columns are present in neither the parent " + "nor the child's mapped tables" % self.prop) + + def _deannotate_pairs(self, collection): + """provide deannotation for the various lists of + pairs, so that using them in hashes doesn't incur + high-overhead __eq__() comparisons against + original columns mapped. + + """ + return [(x._deannotate(), y._deannotate()) + for x, y in collection] + + def _setup_pairs(self): + sync_pairs = [] + lrp = util.OrderedSet([]) + secondary_sync_pairs = [] + + def go(joincond, collection): + def visit_binary(binary, left, right): + if "remote" in right._annotations and \ + "remote" not in left._annotations and \ + self.can_be_synced_fn(left): + lrp.add((left, right)) + elif "remote" in left._annotations and \ + "remote" not in right._annotations and \ + self.can_be_synced_fn(right): + lrp.add((right, left)) + if binary.operator is operators.eq and \ + self.can_be_synced_fn(left, right): + if "foreign" in right._annotations: + collection.append((left, right)) + elif "foreign" in left._annotations: + collection.append((right, left)) + visit_binary_product(visit_binary, joincond) + + for joincond, collection in [ + (self.primaryjoin, sync_pairs), + (self.secondaryjoin, secondary_sync_pairs) + ]: + if joincond is None: + continue + go(joincond, collection) + + self.local_remote_pairs = self._deannotate_pairs(lrp) + self.synchronize_pairs = self._deannotate_pairs(sync_pairs) + self.secondary_synchronize_pairs = self._deannotate_pairs(secondary_sync_pairs) + + @util.memoized_property + def remote_columns(self): + return self._gather_join_annotations("remote") + + @util.memoized_property + def local_columns(self): + return self._gather_join_annotations("local") + + @util.memoized_property + def foreign_key_columns(self): + return self._gather_join_annotations("foreign") + + @util.memoized_property + def deannotated_primaryjoin(self): + return _deep_deannotate(self.primaryjoin) + + @util.memoized_property + def deannotated_secondaryjoin(self): + if self.secondaryjoin is not None: + return _deep_deannotate(self.secondaryjoin) + else: + return None + + def _gather_join_annotations(self, annotation): + s = set( + self._gather_columns_with_annotation( + self.primaryjoin, annotation) + ) + if self.secondaryjoin is not None: + s.update( + self._gather_columns_with_annotation( + self.secondaryjoin, annotation) + ) + return set([x._deannotate() for x in 
s]) + + def _gather_columns_with_annotation(self, clause, *annotation): + annotation = set(annotation) + return set([ + col for col in visitors.iterate(clause, {}) + if annotation.issubset(col._annotations) + ]) + + + def join_targets(self, source_selectable, + dest_selectable, + aliased, + single_crit=None): + """Given a source and destination selectable, create a + join between them. + + This takes into account aliasing the join clause + to reference the appropriate corresponding columns + in the target objects, as well as the extra child + criterion, equivalent column sets, etc. + + """ + + # place a barrier on the destination such that + # replacement traversals won't ever dig into it. + # its internal structure remains fixed + # regardless of context. + dest_selectable = _shallow_annotate( + dest_selectable, + {'no_replacement_traverse':True}) + + primaryjoin, secondaryjoin, secondary = self.primaryjoin, \ + self.secondaryjoin, self.secondary + + # adjust the join condition for single table inheritance, + # in the case that the join is to a subclass + # this is analogous to the + # "_adjust_for_single_table_inheritance()" method in Query. + + if single_crit is not None: + if secondaryjoin is not None: + secondaryjoin = secondaryjoin & single_crit + else: + primaryjoin = primaryjoin & single_crit + + if aliased: + if secondary is not None: + secondary = secondary.alias() + primary_aliasizer = ClauseAdapter(secondary) + secondary_aliasizer = \ + ClauseAdapter(dest_selectable, + equivalents=self.child_equivalents).\ + chain(primary_aliasizer) + if source_selectable is not None: + primary_aliasizer = \ + ClauseAdapter(secondary).\ + chain(ClauseAdapter(source_selectable, + equivalents=self.parent_equivalents)) + secondaryjoin = \ + secondary_aliasizer.traverse(secondaryjoin) + else: + primary_aliasizer = ClauseAdapter(dest_selectable, + exclude_fn=lambda c: "local" in c._annotations, + equivalents=self.child_equivalents) + if source_selectable is not None: + primary_aliasizer.chain( + ClauseAdapter(source_selectable, + exclude_fn=lambda c: "remote" in c._annotations, + equivalents=self.parent_equivalents)) + secondary_aliasizer = None + + primaryjoin = primary_aliasizer.traverse(primaryjoin) + target_adapter = secondary_aliasizer or primary_aliasizer + target_adapter.exclude_fn = None + else: + target_adapter = None + return primaryjoin, secondaryjoin, secondary, \ + target_adapter, dest_selectable + + def create_lazy_clause(self, reverse_direction=False): + binds = util.column_dict() + lookup = util.column_dict() + equated_columns = util.column_dict() + + if reverse_direction and self.secondaryjoin is None: + for l, r in self.local_remote_pairs: + _list = lookup.setdefault(r, []) + _list.append((r, l)) + equated_columns[l] = r + else: + for l, r in self.local_remote_pairs: + _list = lookup.setdefault(l, []) + _list.append((l, r)) + equated_columns[r] = l + + def col_to_bind(col): + if col in lookup: + for tobind, equated in lookup[col]: + if equated in binds: + return None + if col not in binds: + binds[col] = sql.bindparam( + None, None, type_=col.type, unique=True) + return binds[col] + return None + + lazywhere = self.deannotated_primaryjoin + + if self.deannotated_secondaryjoin is None or not reverse_direction: + lazywhere = visitors.replacement_traverse( + lazywhere, {}, col_to_bind) + + if self.deannotated_secondaryjoin is not None: + secondaryjoin = self.deannotated_secondaryjoin + if reverse_direction: + secondaryjoin = visitors.replacement_traverse( + secondaryjoin, {}, 
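+                # when traversing in the reverse direction, bind
+                # params are established within the secondaryjoin as well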
col_to_bind) + lazywhere = sql.and_(lazywhere, secondaryjoin) + + bind_to_col = dict((binds[col].key, col) for col in binds) + + return lazywhere, bind_to_col, equated_columns + + + diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 14778705d..7c2cd8f0e 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -16,7 +16,7 @@ from sqlalchemy.orm import ( from sqlalchemy.orm.util import object_mapper as _object_mapper from sqlalchemy.orm.util import class_mapper as _class_mapper from sqlalchemy.orm.util import ( - _class_to_mapper, _state_mapper, + _class_to_mapper, _state_mapper, object_state ) from sqlalchemy.orm.mapper import Mapper, _none_set from sqlalchemy.orm.unitofwork import UOWTransaction @@ -835,7 +835,7 @@ class Session(object): """ for state in self.identity_map.all_states() + list(self._new): - state.detach() + state._detach() self.identity_map = self._identity_cls() self._new = {} @@ -1135,7 +1135,7 @@ class Session(object): state.expire(state.dict, self.identity_map._modified) elif state in self._new: self._new.pop(state) - state.detach() + state._detach() @util.deprecated("0.7", "The non-weak-referencing identity map " "feature is no longer needed.") @@ -1177,11 +1177,11 @@ class Session(object): def _expunge_state(self, state): if state in self._new: self._new.pop(state) - state.detach() + state._detach() elif self.identity_map.contains_state(state): self.identity_map.discard(state) self._deleted.pop(state, None) - state.detach() + state._detach() elif self.transaction: self.transaction._deleted.pop(state, None) @@ -1460,10 +1460,10 @@ class Session(object): "Object '%s' already has an identity - it can't be registered " "as pending" % mapperutil.state_str(state)) - self._attach(state) if state not in self._new: self._new[state] = state.obj() state.insert_order = len(self._new) + self._attach(state) def _update_impl(self, state): if (self.identity_map.contains_state(state) and @@ -1481,9 +1481,9 @@ class Session(object): "function to send this object back to the transient state." % mapperutil.state_str(state) ) - self._attach(state) self._deleted.pop(state, None) self.identity_map.add(state) + self._attach(state) def _save_or_update_impl(self, state): if state.key is None: @@ -1678,7 +1678,7 @@ class Session(object): def is_modified(self, instance, include_collections=True, - passive=attributes.PASSIVE_OFF): + passive=True): """Return ``True`` if the given instance has locally modified attributes. @@ -1693,19 +1693,19 @@ class Session(object): E.g.:: - return session.is_modified(someobject, passive=True) + return session.is_modified(someobject) .. note:: - In SQLAlchemy 0.7 and earlier, the ``passive`` - flag should **always** be explicitly set to ``True``. - The current default value of :data:`.attributes.PASSIVE_OFF` - for this flag is incorrect, in that it loads unloaded - collections and attributes which by definition - have no modified state, and furthermore trips off - autoflush which then causes all subsequent, possibly - modified attributes to lose their modified state. - The default value of the flag will be changed in 0.8. + When using SQLAlchemy 0.7 and earlier, the ``passive`` + flag should **always** be explicitly set to ``True``, + else SQL loads/autoflushes may proceed which can affect + the modified state itself:: + + session.is_modified(someobject, passive=True) + + In 0.8 and above, the behavior is corrected and + this flag is ignored. 
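+        A minimal usage sketch, assuming a mapped ``User`` class::
+
+            someobject = session.query(User).first()
+            someobject.name = 'some new name'
+            assert session.is_modified(someobject)
+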
A few caveats to this method apply: @@ -1726,7 +1726,7 @@ class Session(object): usually needed, and in those few cases where it isn't, is less expensive on average than issuing a defensive SELECT. - The "old" value is fetched unconditionally only if the attribute + The "old" value is fetched unconditionally upon set only if the attribute container has the ``active_history`` flag set to ``True``. This flag is set typically for primary key attributes and scalar object references that are not a simple many-to-one. To set this flag for @@ -1739,33 +1739,17 @@ class Session(object): only local-column based properties (i.e. scalar columns or many-to-one foreign keys) that would result in an UPDATE for this instance upon flush. - :param passive: Indicates if unloaded attributes and - collections should be loaded in the course of performing - this test. If set to ``False``, or left at its default - value of :data:`.PASSIVE_OFF`, unloaded attributes - will be loaded. If set to ``True`` or - :data:`.PASSIVE_NO_INITIALIZE`, unloaded - collections and attributes will remain unloaded. As - noted previously, the existence of this flag here - is a bug, as unloaded attributes by definition have - no changes, and the load operation also triggers an - autoflush which then cancels out subsequent changes. - This flag should **always be set to - True**. In 0.8 the flag will be deprecated and the default - set to ``True``. - + :param passive: Ignored for backwards compatibility in + 0.8 and above. When using SQLAlchemy 0.7 and earlier, this + flag should always be set to ``True``. """ - try: - state = attributes.instance_state(instance) - except exc.NO_STATE: - raise exc.UnmappedInstanceError(instance) - dict_ = state.dict + state = object_state(instance) - if passive is True: - passive = attributes.PASSIVE_NO_INITIALIZE - elif passive is False: - passive = attributes.PASSIVE_OFF + if not state.modified: + return False + + dict_ = state.dict for attr in state.manager.attributes: if \ @@ -1776,11 +1760,13 @@ class Session(object): continue (added, unchanged, deleted) = \ - attr.impl.get_history(state, dict_, passive=passive) + attr.impl.get_history(state, dict_, + passive=attributes.NO_CHANGE) if added or deleted: return True - return False + else: + return False @property def is_active(self): diff --git a/lib/sqlalchemy/orm/state.py b/lib/sqlalchemy/orm/state.py index 4803ecdc3..bb6104762 100644 --- a/lib/sqlalchemy/orm/state.py +++ b/lib/sqlalchemy/orm/state.py @@ -17,12 +17,12 @@ from sqlalchemy import util from sqlalchemy.orm import exc as orm_exc, attributes, interfaces,\ util as orm_util -from sqlalchemy.orm.attributes import PASSIVE_OFF, PASSIVE_NO_RESULT, \ - PASSIVE_NO_FETCH, NEVER_SET, ATTR_WAS_SET, NO_VALUE +from sqlalchemy.orm.attributes import PASSIVE_NO_RESULT, \ + SQL_OK, NEVER_SET, ATTR_WAS_SET, NO_VALUE,\ + PASSIVE_NO_INITIALIZE mapperlib = util.importlater("sqlalchemy.orm", "mapperlib") - -import sys +sessionlib = util.importlater("sqlalchemy.orm", "session") class InstanceState(object): """tracks state information at the instance level.""" @@ -33,7 +33,6 @@ class InstanceState(object): load_options = EMPTY_SET load_path = () insert_order = None - mutable_dict = None _strong_obj = None modified = False expired = False @@ -47,22 +46,81 @@ class InstanceState(object): self.committed_state = {} @util.memoized_property + def attr(self): + return util.ImmutableProperties( + dict( + (key, InspectAttr(self, key)) + for key in self.manager + ) + ) + + @property + def transient(self): + return self.key 
is None and \ + not self._attached + + @property + def pending(self): + return self.key is None and \ + self._attached + + @property + def persistent(self): + return self.key is not None and \ + self._attached + + @property + def detached(self): + return self.key is not None and \ + not self._attached + + @property + def _attached(self): + return self.session_id is not None and \ + self.session_id in sessionlib._sessions + + @property + def session(self): + return sessionlib._state_session(self) + + @property + def object(self): + return self.obj() + + @property + def identity(self): + if self.key is None: + return None + else: + return self.key[1] + + @property + def identity_key(self): + # TODO: just change .key to .identity_key across + # the board ? probably + return self.key + + @util.memoized_property def parents(self): return {} @util.memoized_property - def pending(self): + def _pending_mutations(self): return {} + @util.memoized_property + def mapper(self): + return self.manager.mapper + @property def has_identity(self): return bool(self.key) - def detach(self): + def _detach(self): self.session_id = None - def dispose(self): - self.detach() + def _dispose(self): + self._detach() del self.obj def _cleanup(self, ref): @@ -91,9 +149,6 @@ class InstanceState(object): manager.dispatch.init(self, args, kwargs) - #if manager.mutable_attributes: - # assert self.__class__ is MutableAttrInstanceState - try: return manager.original_init(*mixed[1:], **kwargs) except: @@ -106,36 +161,17 @@ class InstanceState(object): def get_impl(self, key): return self.manager[key].impl - def get_pending(self, key): - if key not in self.pending: - self.pending[key] = PendingCollection() - return self.pending[key] - - def value_as_iterable(self, dict_, key, passive=PASSIVE_OFF): - """Return a list of tuples (state, obj) for the given - key. 
- - returns an empty list if the value is None/empty/PASSIVE_NO_RESULT - """ - - impl = self.manager[key].impl - x = impl.get(self, dict_, passive=passive) - if x is PASSIVE_NO_RESULT or x is None: - return [] - elif hasattr(impl, 'get_collection'): - return [ - (attributes.instance_state(o), o) for o in - impl.get_collection(self, dict_, x, passive=passive) - ] - else: - return [(attributes.instance_state(x), x)] + def _get_pending_mutation(self, key): + if key not in self._pending_mutations: + self._pending_mutations[key] = PendingCollection() + return self._pending_mutations[key] def __getstate__(self): d = {'instance':self.obj()} d.update( (k, self.__dict__[k]) for k in ( - 'committed_state', 'pending', 'modified', 'expired', - 'callables', 'key', 'parents', 'load_options', 'mutable_dict', + 'committed_state', '_pending_mutations', 'modified', 'expired', + 'callables', 'key', 'parents', 'load_options', 'class_', ) if k in self.__dict__ ) @@ -169,7 +205,7 @@ class InstanceState(object): mapperlib.configure_mappers() self.committed_state = state.get('committed_state', {}) - self.pending = state.get('pending', {}) + self._pending_mutations = state.get('_pending_mutations', {}) self.parents = state.get('parents', {}) self.modified = state.get('modified', False) self.expired = state.get('expired', False) @@ -180,7 +216,7 @@ class InstanceState(object): self.__dict__.update([ (k, state[k]) for k in ( - 'key', 'load_options', 'mutable_dict' + 'key', 'load_options', ) if k in state ]) @@ -234,8 +270,7 @@ class InstanceState(object): self.committed_state.clear() - self.__dict__.pop('pending', None) - self.__dict__.pop('mutable_dict', None) + self.__dict__.pop('_pending_mutations', None) # clear out 'parents' collection. not # entirely clear how we can best determine @@ -252,8 +287,7 @@ class InstanceState(object): self.manager.dispatch.expire(self, None) def expire_attributes(self, dict_, attribute_names): - pending = self.__dict__.get('pending', None) - mutable_dict = self.mutable_dict + pending = self.__dict__.get('_pending_mutations', None) for key in attribute_names: impl = self.manager[key].impl @@ -262,8 +296,6 @@ class InstanceState(object): dict_.pop(key, None) self.committed_state.pop(key, None) - if mutable_dict: - mutable_dict.pop(key, None) if pending: pending.pop(key, None) @@ -276,7 +308,7 @@ class InstanceState(object): """ - if passive is PASSIVE_NO_FETCH: + if not passive & SQL_OK: return PASSIVE_NO_RESULT toload = self.expired_attributes.\ @@ -336,7 +368,7 @@ class InstanceState(object): def _is_really_none(self): return self.obj() - def modified_event(self, dict_, attr, previous, collection=False): + def _modified_event(self, dict_, attr, previous, collection=False): if attr.key not in self.committed_state: if collection: if previous is NEVER_SET: @@ -381,15 +413,8 @@ class InstanceState(object): """ class_manager = self.manager - if class_manager.mutable_attributes: - for key in keys: - if key in dict_ and key in class_manager.mutable_attributes: - self.committed_state[key] = self.manager[key].impl.copy(dict_[key]) - else: - self.committed_state.pop(key, None) - else: - for key in keys: - self.committed_state.pop(key, None) + for key in keys: + self.committed_state.pop(key, None) self.expired = False @@ -415,131 +440,40 @@ class InstanceState(object): """ self.committed_state.clear() - self.__dict__.pop('pending', None) + self.__dict__.pop('_pending_mutations', None) callables = self.callables for key in list(callables): if key in dict_ and callables[key] is self: del 
callables[key] - for key in self.manager.mutable_attributes: - if key in dict_: - self.committed_state[key] = self.manager[key].impl.copy(dict_[key]) - if instance_dict and self.modified: instance_dict._modified.discard(self) self.modified = self.expired = False self._strong_obj = None -class MutableAttrInstanceState(InstanceState): - """InstanceState implementation for objects that reference 'mutable' - attributes. +class InspectAttr(object): + """Provide inspection interface to an object's state.""" - Has a more involved "cleanup" handler that checks mutable attributes - for changes upon dereference, resurrecting if needed. - - """ - - @util.memoized_property - def mutable_dict(self): - return {} - - def _get_modified(self, dict_=None): - if self.__dict__.get('modified', False): - return True - else: - if dict_ is None: - dict_ = self.dict - for key in self.manager.mutable_attributes: - if self.manager[key].impl.check_mutable_modified(self, dict_): - return True - else: - return False - - def _set_modified(self, value): - self.__dict__['modified'] = value - - modified = property(_get_modified, _set_modified) + def __init__(self, state, key): + self.state = state + self.key = key @property - def unmodified(self): - """a set of keys which have no uncommitted changes""" - - dict_ = self.dict - - return set([ - key for key in self.manager - if (key not in self.committed_state or - (key in self.manager.mutable_attributes and - not self.manager[key].impl.check_mutable_modified(self, dict_)))]) + def loaded_value(self): + return self.state.dict.get(self.key, NO_VALUE) - def unmodified_intersection(self, keys): - """Return self.unmodified.intersection(keys).""" - - dict_ = self.dict - - return set([ - key for key in keys - if (key not in self.committed_state or - (key in self.manager.mutable_attributes and - not self.manager[key].impl.check_mutable_modified(self, dict_)))]) - - - def _is_really_none(self): - """do a check modified/resurrect. - - This would be called in the extremely rare - race condition that the weakref returned None but - the cleanup handler had not yet established the - __resurrect callable as its replacement. - - """ - if self.modified: - self.obj = self.__resurrect - return self.obj() - else: - return None - - def reset(self, dict_, key): - self.mutable_dict.pop(key, None) - InstanceState.reset(self, dict_, key) - - def _cleanup(self, ref): - """weakref callback. - - This method may be called by an asynchronous - gc. - - If the state shows pending changes, the weakref - is replaced by the __resurrect callable which will - re-establish an object reference on next access, - else removes this InstanceState from the owning - identity map, if any. 
- - """ - if self._get_modified(self.mutable_dict): - self.obj = self.__resurrect - else: - instance_dict = self._instance_dict() - if instance_dict: - instance_dict.discard(self) - self.dispose() - - def __resurrect(self): - """A substitute for the obj() weakref function which resurrects.""" - - # store strong ref'ed version of the object; will revert - # to weakref when changes are persisted - obj = self.manager.new_instance(state=self) - self.obj = weakref.ref(obj, self._cleanup) - self._strong_obj = obj - obj.__dict__.update(self.mutable_dict) + @property + def value(self): + return self.state.manager[self.key].__get__( + self.state.obj(), self.state.class_) - # re-establishes identity attributes from the key - self.manager.dispatch.resurrect(self) + @property + def history(self): + return self.state.get_history(self.key, + PASSIVE_NO_INITIALIZE) - return obj class PendingCollection(object): """A writable placeholder for an unloaded collection. diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 5f4b182d0..4cf32335f 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -25,8 +25,6 @@ import itertools def _register_attribute(strategy, mapper, useobject, compare_function=None, typecallable=None, - copy_function=None, - mutable_scalars=False, uselist=False, callable_=None, proxy_property=None, @@ -45,11 +43,11 @@ def _register_attribute(strategy, mapper, useobject, listen_hooks.append(single_parent_validator) if prop.key in prop.parent.validators: + fn, include_removes = prop.parent.validators[prop.key] listen_hooks.append( lambda desc, prop: mapperutil._validator_events(desc, - prop.key, - prop.parent.validators[prop.key]) - ) + prop.key, fn, include_removes) + ) if useobject: listen_hooks.append(unitofwork.track_cascade_events) @@ -71,9 +69,7 @@ def _register_attribute(strategy, mapper, useobject, m.class_, prop.key, parent_token=prop, - mutable_scalars=mutable_scalars, uselist=uselist, - copy_function=copy_function, compare_function=compare_function, useobject=useobject, extension=attribute_ext, @@ -132,8 +128,6 @@ class ColumnLoader(LoaderStrategy): _register_attribute(self, mapper, useobject=False, compare_function=coltype.compare_values, - copy_function=coltype.copy_value, - mutable_scalars=self.columns[0].type.is_mutable(), active_history = active_history ) @@ -193,8 +187,6 @@ class DeferredColumnLoader(LoaderStrategy): _register_attribute(self, mapper, useobject=False, compare_function=self.columns[0].type.compare_values, - copy_function=self.columns[0].type.copy_value, - mutable_scalars=self.columns[0].type.is_mutable(), callable_=self._load_for_state, expire_missing=False ) @@ -213,7 +205,7 @@ class DeferredColumnLoader(LoaderStrategy): if not state.key: return attributes.ATTR_EMPTY - if passive is attributes.PASSIVE_NO_FETCH: + if not passive & attributes.SQL_OK: return attributes.PASSIVE_NO_RESULT localparent = state.manager.mapper @@ -324,14 +316,14 @@ class LazyLoader(AbstractRelationshipLoader): def init(self): super(LazyLoader, self).init() + join_condition = self.parent_property._join_condition self._lazywhere, \ self._bind_to_col, \ - self._equated_columns = self._create_lazy_clause(self.parent_property) + self._equated_columns = join_condition.create_lazy_clause() self._rev_lazywhere, \ self._rev_bind_to_col, \ - self._rev_equated_columns = self._create_lazy_clause( - self.parent_property, + self._rev_equated_columns = join_condition.create_lazy_clause( reverse_direction=True) self.logger.info("%s lazy 
loading clause %s", self, self._lazywhere) @@ -464,13 +456,10 @@ class LazyLoader(AbstractRelationshipLoader): ident_key = None if ( - (passive is attributes.PASSIVE_NO_FETCH or \ - passive is attributes.PASSIVE_NO_FETCH_RELATED) and - not self.use_get - ) or ( - passive is attributes.PASSIVE_ONLY_PERSISTENT and - pending - ): + (not passive & attributes.SQL_OK and not self.use_get) + or + (not passive & attributes.NON_PERSISTENT_OK and pending) + ): return attributes.PASSIVE_NO_RESULT session = sessionlib._state_session(state) @@ -501,8 +490,8 @@ class LazyLoader(AbstractRelationshipLoader): instance = Query._get_from_identity(session, ident_key, passive) if instance is not None: return instance - elif passive is attributes.PASSIVE_NO_FETCH or \ - passive is attributes.PASSIVE_NO_FETCH_RELATED: + elif not passive & attributes.SQL_OK or \ + not passive & attributes.RELATED_OBJECT_OK: return attributes.PASSIVE_NO_RESULT return self._emit_lazyload(session, state, ident_key) @@ -517,17 +506,12 @@ class LazyLoader(AbstractRelationshipLoader): dict_ = state.dict - if passive is attributes.PASSIVE_NO_FETCH_RELATED: - attr_passive = attributes.PASSIVE_OFF - else: - attr_passive = passive - return [ get_attr( state, dict_, self._equated_columns[pk], - passive=attr_passive) + passive=passive) for pk in self.mapper.primary_key ] @@ -617,49 +601,6 @@ class LazyLoader(AbstractRelationshipLoader): return reset_for_lazy_callable, None, None - @classmethod - def _create_lazy_clause(cls, prop, reverse_direction=False): - binds = util.column_dict() - lookup = util.column_dict() - equated_columns = util.column_dict() - - if reverse_direction and prop.secondaryjoin is None: - for l, r in prop.local_remote_pairs: - _list = lookup.setdefault(r, []) - _list.append((r, l)) - equated_columns[l] = r - else: - for l, r in prop.local_remote_pairs: - _list = lookup.setdefault(l, []) - _list.append((l, r)) - equated_columns[r] = l - - def col_to_bind(col): - if col in lookup: - for tobind, equated in lookup[col]: - if equated in binds: - return None - if col not in binds: - binds[col] = sql.bindparam(None, None, type_=col.type, unique=True) - return binds[col] - return None - - lazywhere = prop.primaryjoin - - if prop.secondaryjoin is None or not reverse_direction: - lazywhere = visitors.replacement_traverse( - lazywhere, {}, col_to_bind) - - if prop.secondaryjoin is not None: - secondaryjoin = prop.secondaryjoin - if reverse_direction: - secondaryjoin = visitors.replacement_traverse( - secondaryjoin, {}, col_to_bind) - lazywhere = sql.and_(lazywhere, secondaryjoin) - - bind_to_col = dict((binds[col].key, col) for col in binds) - - return lazywhere, bind_to_col, equated_columns log.class_logger(LazyLoader) @@ -785,7 +726,8 @@ class SubqueryLoader(AbstractRelationshipLoader): leftmost_mapper, leftmost_prop = \ subq_mapper, \ subq_mapper._props[subq_path[1]] - leftmost_cols, remote_cols = self._local_remote_columns(leftmost_prop) + + leftmost_cols = leftmost_prop.local_columns leftmost_attr = [ leftmost_mapper._columntoproperty[c].class_attribute @@ -799,7 +741,7 @@ class SubqueryLoader(AbstractRelationshipLoader): ): # reformat the original query # to look only for significant columns - q = orig_query._clone() + q = orig_query._clone().correlate(None) # TODO: why does polymporphic etc. require hardcoding # into _adapt_col_list ? Does query.add_columns(...) 
work @@ -846,8 +788,7 @@ class SubqueryLoader(AbstractRelationshipLoader): # self.parent is more specific than subq_path[-2] parent_alias = mapperutil.AliasedClass(self.parent) - local_cols, remote_cols = \ - self._local_remote_columns(self.parent_property) + local_cols = self.parent_property.local_columns local_attr = [ getattr(parent_alias, self.parent._columntoproperty[c].key) @@ -881,17 +822,6 @@ class SubqueryLoader(AbstractRelationshipLoader): q = q.join(attr, aliased=middle, from_joinpoint=True) return q - def _local_remote_columns(self, prop): - if prop.secondary is None: - return zip(*prop.local_remote_pairs) - else: - return \ - [p[0] for p in prop.synchronize_pairs],\ - [ - p[0] for p in prop. - secondary_synchronize_pairs - ] - def _setup_options(self, q, subq_path, orig_query): # propagate loader options etc. to the new query. # these will fire relative to subq_path. @@ -930,7 +860,7 @@ class SubqueryLoader(AbstractRelationshipLoader): if ('subquery', reduced_path) not in context.attributes: return None, None, None - local_cols, remote_cols = self._local_remote_columns(self.parent_property) + local_cols = self.parent_property.local_columns q = context.attributes[('subquery', reduced_path)] diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py index 8fc5f139d..3523e7d06 100644 --- a/lib/sqlalchemy/orm/unitofwork.py +++ b/lib/sqlalchemy/orm/unitofwork.py @@ -166,8 +166,9 @@ class UOWTransaction(object): history, state_history, cached_passive = self.attributes[hashkey] # if the cached lookup was "passive" and now # we want non-passive, do a non-passive lookup and re-cache - if cached_passive is not attributes.PASSIVE_OFF \ - and passive is attributes.PASSIVE_OFF: + + if not cached_passive & attributes.SQL_OK \ + and passive & attributes.SQL_OK: impl = state.manager[key].impl history = impl.get_history(state, state.dict, attributes.PASSIVE_OFF) diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 0c5f203a7..5fcb15a9a 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -5,7 +5,7 @@ # the MIT License: http://www.opensource.org/licenses/mit-license.php -from sqlalchemy import sql, util, event, exc as sa_exc +from sqlalchemy import sql, util, event, exc as sa_exc, inspection from sqlalchemy.sql import expression, util as sql_util, operators from sqlalchemy.orm.interfaces import MapperExtension, EXT_CONTINUE,\ PropComparator, MapperProperty @@ -68,24 +68,36 @@ class CascadeOptions(frozenset): ",".join([x for x in sorted(self)]) ) -def _validator_events(desc, key, validator): +def _validator_events(desc, key, validator, include_removes): """Runs a validation method on an attribute value to be set or appended.""" - def append(state, value, initiator): - return validator(state.obj(), key, value) + if include_removes: + def append(state, value, initiator): + return validator(state.obj(), key, value, False) - def set_(state, value, oldvalue, initiator): - return validator(state.obj(), key, value) + def set_(state, value, oldvalue, initiator): + return validator(state.obj(), key, value, False) + + def remove(state, value, initiator): + validator(state.obj(), key, value, True) + else: + def append(state, value, initiator): + return validator(state.obj(), key, value) + + def set_(state, value, oldvalue, initiator): + return validator(state.obj(), key, value) event.listen(desc, 'append', append, raw=True, retval=True) event.listen(desc, 'set', set_, raw=True, retval=True) + if include_removes: + event.listen(desc, 
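+                     # validators that opt in via include_removes also
+                     # receive collection "remove" events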
"remove", remove, raw=True, retval=True) def polymorphic_union(table_map, typecolname, aliasname='p_union', cast_nulls=True): """Create a ``UNION`` statement used by a polymorphic mapper. See :ref:`concrete_inheritance` for an example of how this is used. - + :param table_map: mapping of polymorphic identities to :class:`.Table` objects. :param typecolname: string name of a "discriminator" column, which will be @@ -236,7 +248,7 @@ class AliasedClass(object): session.query(User, user_alias).\\ join((user_alias, User.id > user_alias.id)).\\ filter(User.name==user_alias.name) - + The resulting object is an instance of :class:`.AliasedClass`, however it implements a ``__getattribute__()`` scheme which will proxy attribute access to that of the ORM class being aliased. All classmethods @@ -244,7 +256,7 @@ class AliasedClass(object): hybrids created with the :ref:`hybrids_toplevel` extension, which will receive the :class:`.AliasedClass` as the "class" argument when classmethods are called. - + :param cls: ORM mapped entity which will be "wrapped" around an alias. :param alias: a selectable, such as an :func:`.alias` or :func:`.select` construct, which will be rendered in place of the mapped table of the @@ -259,39 +271,51 @@ class AliasedClass(object): otherwise have a column that corresponds to one on the entity. The use case for this is when associating an entity with some derived selectable such as one that uses aggregate functions:: - + class UnitPrice(Base): __tablename__ = 'unit_price' ... unit_id = Column(Integer) price = Column(Numeric) - + aggregated_unit_price = Session.query( func.sum(UnitPrice.price).label('price') ).group_by(UnitPrice.unit_id).subquery() - + aggregated_unit_price = aliased(UnitPrice, alias=aggregated_unit_price, adapt_on_names=True) - + Above, functions on ``aggregated_unit_price`` which refer to ``.price`` will return the ``fund.sum(UnitPrice.price).label('price')`` column, as it is matched on the name "price". Ordinarily, the "price" function wouldn't have any "column correspondence" to the actual ``UnitPrice.price`` column as it is not a proxy of the original. - + ``adapt_on_names`` is new in 0.7.3. - + """ - def __init__(self, cls, alias=None, name=None, adapt_on_names=False): + def __init__(self, cls, alias=None, + name=None, + adapt_on_names=False, + with_polymorphic_mappers=(), + with_polymorphic_discriminator=None): self.__mapper = _class_to_mapper(cls) self.__target = self.__mapper.class_ self.__adapt_on_names = adapt_on_names if alias is None: - alias = self.__mapper._with_polymorphic_selectable.alias(name=name) + alias = self.__mapper._with_polymorphic_selectable.alias( + name=name) self.__adapter = sql_util.ClauseAdapter(alias, - equivalents=self.__mapper._equivalent_columns, - adapt_on_names=self.__adapt_on_names) + equivalents=self.__mapper._equivalent_columns, + adapt_on_names=self.__adapt_on_names) self.__alias = alias + self.__with_polymorphic_mappers = with_polymorphic_mappers + self.__with_polymorphic_discriminator = \ + with_polymorphic_discriminator + for poly in with_polymorphic_mappers: + setattr(self, poly.class_.__name__, + AliasedClass(poly.class_, alias)) + # used to assign a name to the RowTuple object # returned by Query. 
self._sa_label_name = name @@ -303,6 +327,10 @@ class AliasedClass(object): 'alias':self.__alias, 'name':self._sa_label_name, 'adapt_on_names':self.__adapt_on_names, + 'with_polymorphic_mappers': + self.__with_polymorphic_mappers, + 'with_polymorphic_discriminator': + self.__with_polymorphic_discriminator } def __setstate__(self, state): @@ -311,9 +339,13 @@ class AliasedClass(object): self.__adapt_on_names = state['adapt_on_names'] alias = state['alias'] self.__adapter = sql_util.ClauseAdapter(alias, - equivalents=self.__mapper._equivalent_columns, - adapt_on_names=self.__adapt_on_names) + equivalents=self.__mapper._equivalent_columns, + adapt_on_names=self.__adapt_on_names) self.__alias = alias + self.__with_polymorphic_mappers = \ + state.get('with_polymorphic_mappers') + self.__with_polymorphic_discriminator = \ + state.get('with_polymorphic_discriminator') name = state['name'] self._sa_label_name = name self.__name__ = 'AliasedClass_' + str(self.__target) @@ -367,10 +399,75 @@ class AliasedClass(object): def aliased(element, alias=None, name=None, adapt_on_names=False): if isinstance(element, expression.FromClause): if adapt_on_names: - raise sa_exc.ArgumentError("adapt_on_names only applies to ORM elements") + raise sa_exc.ArgumentError( + "adapt_on_names only applies to ORM elements" + ) return element.alias(name) else: - return AliasedClass(element, alias=alias, name=name, adapt_on_names=adapt_on_names) + return AliasedClass(element, alias=alias, + name=name, adapt_on_names=adapt_on_names) + +def with_polymorphic(base, classes, selectable=False, + polymorphic_on=None, aliased=False): + """Produce an :class:`.AliasedClass` construct which specifies + columns for descendant mappers of the given base. + + .. note:: + + :func:`.orm.with_polymorphic` is new in version 0.8. + It is in addition to the existing :class:`.Query` method + :meth:`.Query.with_polymorphic`, which has the same purpose + but is not as flexible in its usage. + + Using this method will ensure that each descendant mapper's + tables are included in the FROM clause, and will allow filter() + criterion to be used against those tables. The resulting + instances will also have those columns already loaded so that + no "post fetch" of those columns will be required. + + See the examples at :ref:`with_polymorphic`. + + :param base: Base class to be aliased. + + :param cls_or_mappers: a single class or mapper, or list of + class/mappers, which inherit from the base class. + Alternatively, it may also be the string ``'*'``, in which case + all descending mapped classes will be added to the FROM clause. + + :param aliased: when True, the selectable will be wrapped in an + alias, that is ``(SELECT * FROM <fromclauses>) AS anon_1``. + This can be important when using the with_polymorphic() + to create the target of a JOIN on a backend that does not + support parenthesized joins, such as SQLite and older + versions of MySQL. + + :param selectable: a table or select() statement that will + be used in place of the generated FROM clause. This argument is + required if any of the desired classes use concrete table + inheritance, since SQLAlchemy currently cannot generate UNIONs + among tables automatically. If used, the ``selectable`` argument + must represent the full set of tables and columns mapped by every + mapped class. Otherwise, the unaccounted mapped columns will + result in their table being appended directly to the FROM clause + which will usually lead to incorrect results. 
+ + :param polymorphic_on: a column to be used as the "discriminator" + column for the given selectable. If not given, the polymorphic_on + attribute of the base classes' mapper will be used, if any. This + is useful for mappings that don't have polymorphic loading + behavior by default. + + """ + primary_mapper = class_mapper(base) + mappers, selectable = primary_mapper.\ + _with_polymorphic_args(classes, selectable) + if aliased: + selectable = selectable.alias() + return AliasedClass(base, + selectable, + with_polymorphic_mappers=mappers, + with_polymorphic_discriminator=polymorphic_on) + def _orm_annotate(element, exclude=None): """Deep copy the given ClauseElement, annotating each element with the @@ -381,7 +478,21 @@ def _orm_annotate(element, exclude=None): """ return sql_util._deep_annotate(element, {'_orm_adapt':True}, exclude) -_orm_deannotate = sql_util._deep_deannotate +def _orm_deannotate(element): + """Remove annotations that link a column to a particular mapping. + + Note this doesn't affect "remote" and "foreign" annotations + passed by the :func:`.orm.foreign` and :func:`.orm.remote` + annotators. + + """ + + return sql_util._deep_deannotate(element, + values=("_orm_adapt", "parententity") + ) + +def _orm_full_deannotate(element): + return sql_util._deep_deannotate(element) class _ORMJoin(expression.Join): """Extend Join to support ORM constructs as input.""" @@ -447,7 +558,7 @@ class _ORMJoin(expression.Join): def join(left, right, onclause=None, isouter=False, join_to_left=True): """Produce an inner join between left and right clauses. - + :func:`.orm.join` is an extension to the core join interface provided by :func:`.sql.expression.join()`, where the left and right selectables may be not only core selectable @@ -460,7 +571,7 @@ def join(left, right, onclause=None, isouter=False, join_to_left=True): in whatever form it is passed, to the selectable passed as the left side. If False, the onclause is used as is. - + :func:`.orm.join` is not commonly needed in modern usage, as its functionality is encapsulated within that of the :meth:`.Query.join` method, which features a @@ -468,22 +579,22 @@ def join(left, right, onclause=None, isouter=False, join_to_left=True): by itself. Explicit usage of :func:`.orm.join` with :class:`.Query` involves usage of the :meth:`.Query.select_from` method, as in:: - + from sqlalchemy.orm import join session.query(User).\\ select_from(join(User, Address, User.addresses)).\\ filter(Address.email_address=='foo@bar.com') - + In modern SQLAlchemy the above join can be written more succinctly as:: - + session.query(User).\\ join(User.addresses).\\ filter(Address.email_address=='foo@bar.com') See :meth:`.Query.join` for information on modern usage of ORM level joins. - + """ return _ORMJoin(left, right, onclause, isouter, join_to_left) @@ -534,19 +645,13 @@ def with_parent(instance, prop): value_is_parent=True) -def _entity_info(entity, compile=True): - """Return mapping information given a class, mapper, or AliasedClass. - - Returns 3-tuple of: mapper, mapped selectable, boolean indicating if this - is an aliased() construct. - - If the given entity is not a mapper, mapped class, or aliased construct, - returns None, the entity, False. This is typically used to allow - unmapped selectables through. 
- - """ +def _extended_entity_info(entity, compile=True): if isinstance(entity, AliasedClass): - return entity._AliasedClass__mapper, entity._AliasedClass__alias, True + return entity._AliasedClass__mapper, \ + entity._AliasedClass__alias, \ + True, \ + entity._AliasedClass__with_polymorphic_mappers, \ + entity._AliasedClass__with_polymorphic_discriminator if isinstance(entity, mapperlib.Mapper): mapper = entity @@ -555,15 +660,32 @@ def _entity_info(entity, compile=True): class_manager = attributes.manager_of_class(entity) if class_manager is None: - return None, entity, False + return None, entity, False, [], None mapper = class_manager.mapper else: - return None, entity, False + return None, entity, False, [], None if compile and mapperlib.module._new_mappers: mapperlib.configure_mappers() - return mapper, mapper._with_polymorphic_selectable, False + return mapper, \ + mapper._with_polymorphic_selectable, \ + False, \ + mapper._with_polymorphic_mappers, \ + mapper.polymorphic_on + +def _entity_info(entity, compile=True): + """Return mapping information given a class, mapper, or AliasedClass. + + Returns 3-tuple of: mapper, mapped selectable, boolean indicating if this + is an aliased() construct. + + If the given entity is not a mapper, mapped class, or aliased construct, + returns None, the entity, False. This is typically used to allow + unmapped selectables through. + + """ + return _extended_entity_info(entity, compile)[0:3] def _entity_descriptor(entity, key): """Return a class attribute given an entity and string name. @@ -616,15 +738,32 @@ def object_mapper(instance): Raises UnmappedInstanceError if no mapping is configured. + This function is available via the inspection system as:: + + inspect(instance).mapper + + """ + return object_state(instance).mapper + +@inspection._inspects(object) +def object_state(instance): + """Given an object, return the primary Mapper associated with the object + instance. + + Raises UnmappedInstanceError if no mapping is configured. + + This function is available via the inspection system as:: + + inspect(instance) + """ try: - state = attributes.instance_state(instance) - return state.manager.mapper - except exc.UnmappedClassError: - raise exc.UnmappedInstanceError(instance) - except exc.NO_STATE: + return attributes.instance_state(instance) + except (exc.UnmappedClassError, exc.NO_STATE): raise exc.UnmappedInstanceError(instance) + +@inspection._inspects(type) def class_mapper(class_, compile=True): """Given a class, return the primary :class:`.Mapper` associated with the key. @@ -633,6 +772,10 @@ def class_mapper(class_, compile=True): on the given class, or :class:`.ArgumentError` if a non-class object is passed. + This function is available via the inspection system as:: + + inspect(some_mapped_class) + """ try: diff --git a/lib/sqlalchemy/processors.py b/lib/sqlalchemy/processors.py index c4bac2834..a3adbe293 100644 --- a/lib/sqlalchemy/processors.py +++ b/lib/sqlalchemy/processors.py @@ -20,6 +20,7 @@ def str_to_datetime_processor_factory(regexp, type_): rmatch = regexp.match # Even on python2.6 datetime.strptime is both slower than this code # and it does not support microseconds. 
+ has_named_groups = bool(regexp.groupindex) def process(value): if value is None: return None @@ -32,7 +33,12 @@ def str_to_datetime_processor_factory(regexp, type_): if m is None: raise ValueError("Couldn't parse %s string: " "'%s'" % (type_.__name__ , value)) - return type_(*map(int, m.groups(0))) + if has_named_groups: + groups = m.groupdict(0) + return type_(**dict(zip(groups.iterkeys(), + map(int, groups.itervalues())))) + else: + return type_(*map(int, m.groups(0))) return process def boolean_to_int(value): diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index 9746af228..f710ae736 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -1084,7 +1084,7 @@ class Column(SchemaItem, expression.ColumnClause): c.dispatch._update(self.dispatch) return c - def _make_proxy(self, selectable, name=None): + def _make_proxy(self, selectable, name=None, key=None): """Create a *proxy* for this column. This is a copy of this ``Column`` referenced by a different parent @@ -1102,7 +1102,7 @@ class Column(SchemaItem, expression.ColumnClause): c = self._constructor( expression._as_truncated(name or self.name), self.type, - key = name or self.key, + key = key if key else name if name else self.key, primary_key = self.primary_key, nullable = self.nullable, quote=self.quote, _proxies=[self], *fk) @@ -1980,6 +1980,14 @@ class CheckConstraint(Constraint): self.sqltext = expression._literal_as_text(sqltext) if table is not None: self._set_parent_with_dispatch(table) + else: + cols = sqlutil.find_columns(self.sqltext) + tables = set([c.table for c in cols + if c.table is not None]) + if len(tables) == 1: + self._set_parent_with_dispatch( + tables.pop()) + def __visit_name__(self): if isinstance(self.parent, Table): @@ -2083,6 +2091,11 @@ class ForeignKeyConstraint(Constraint): if table is not None: self._set_parent_with_dispatch(table) + elif columns and \ + isinstance(columns[0], Column) and \ + columns[0].table is not None: + self._set_parent_with_dispatch(columns[0].table) + @property def columns(self): diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index fdff99fb1..218e48bca 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -365,7 +365,9 @@ class SQLCompiler(engine.Compiled): labelname = label.name if result_map is not None: - result_map[labelname.lower()] = ( + result_map[labelname + if self.dialect.case_sensitive + else labelname.lower()] = ( label.name, (label, label.element, labelname, ) + label._alt_names, @@ -393,7 +395,9 @@ class SQLCompiler(engine.Compiled): name = self._truncated_identifier("colident", name) if result_map is not None: - result_map[name.lower()] = (orig_name, + result_map[name + if self.dialect.case_sensitive + else name.lower()] = (orig_name, (column, name, column.key), column.type) @@ -441,7 +445,10 @@ class SQLCompiler(engine.Compiled): def visit_textclause(self, textclause, **kwargs): if textclause.typemap is not None: for colname, type_ in textclause.typemap.iteritems(): - self.result_map[colname.lower()] = (colname, None, type_) + self.result_map[colname + if self.dialect.case_sensitive + else colname.lower()] = \ + (colname, None, type_) def do_bindparam(m): name = m.group(1) @@ -518,7 +525,10 @@ class SQLCompiler(engine.Compiled): def visit_function(self, func, result_map=None, **kwargs): if result_map is not None: - result_map[func.name.lower()] = (func.name, None, func.type) + result_map[func.name + if self.dialect.case_sensitive + else func.name.lower()] = \ + (func.name, 
None, func.type) disp = getattr(self, "visit_%s_func" % func.name.lower(), None) if disp: @@ -1115,7 +1125,7 @@ class SQLCompiler(engine.Compiled): """Provide a hook to override the generation of an UPDATE..FROM clause. - MySQL overrides this. + MySQL and MSSQL override this. """ return "FROM " + ', '.join( diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index f37faa801..a0f0bab6c 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -1584,18 +1584,35 @@ class ClauseElement(Visitable): return id(self) def _annotate(self, values): - """return a copy of this ClauseElement with the given annotations - dictionary. + """return a copy of this ClauseElement with annotations + updated by the given dictionary. """ return sqlutil.Annotated(self, values) - def _deannotate(self): - """return a copy of this ClauseElement with an empty annotations - dictionary. + def _with_annotations(self, values): + """return a copy of this ClauseElement with annotations + replaced by the given dictionary. """ - return self._clone() + return sqlutil.Annotated(self, values) + + def _deannotate(self, values=None, clone=False): + """return a copy of this :class:`.ClauseElement` with annotations + removed. + + :param values: optional tuple of individual values + to remove. + + """ + if clone: + # clone is used when we are also copying + # the expression for a deep deannotation + return self._clone() + else: + # if no clone, since we have no annotations we return + # self + return self def unique_params(self, *optionaldict, **kwargs): """Return a copy with :func:`bindparam()` elements replaced. @@ -2146,7 +2163,7 @@ class ColumnElement(ClauseElement, _CompareMixin): return hasattr(other, 'name') and hasattr(self, 'name') and \ other.name == self.name - def _make_proxy(self, selectable, name=None): + def _make_proxy(self, selectable, name=None, **kw): """Create a new :class:`.ColumnElement` representing this :class:`.ColumnElement` as it appears in the select list of a descending selectable. 
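The revised annotation protocol distinguishes updating (``_annotate``),
replacing (``_with_annotations``) and removing (``_deannotate``)
annotations. A minimal sketch of the intended semantics, assuming
``col`` is a :class:`.Column` (variable names are illustrative only)::

    a = col._annotate({'foo': 'bar'})        # annotations: {'foo': 'bar'}
    b = a._annotate({'bat': 'hoho'})         # updated:  {'foo': 'bar', 'bat': 'hoho'}
    c = b._with_annotations({'bat': 'hoho'}) # replaced: {'bat': 'hoho'}
    d = b._deannotate(values=('foo',))       # removed:  {'bat': 'hoho'}
    e = b._deannotate()                      # the original, unannotated element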
@@ -2154,14 +2171,10 @@ class ColumnElement(ClauseElement, _CompareMixin): """ if name is None: name = self.anon_label - # TODO: may want to change this to anon_label, - # or some value that is more useful than the - # compiled form of the expression key = str(self) else: key = name - - co = ColumnClause(_as_truncated(name), + co = ColumnClause(_as_truncated(name), selectable, type_=getattr(self, 'type', None)) @@ -2195,7 +2208,7 @@ class ColumnElement(ClauseElement, _CompareMixin): for oth in to_compare: if use_proxies and self.shares_lineage(oth): return True - elif oth is self: + elif hash(oth) == hash(self): return True else: return False @@ -3403,6 +3416,10 @@ class _BinaryExpression(ColumnElement): raise TypeError("Boolean value of this clause is not defined") @property + def is_comparison(self): + return operators.is_comparison(self.operator) + + @property def _from_objects(self): return self.left._from_objects + self.right._from_objects @@ -3960,8 +3977,9 @@ class _Label(ColumnElement): def _from_objects(self): return self.element._from_objects - def _make_proxy(self, selectable, name = None): - e = self.element._make_proxy(selectable, name=name or self.name) + def _make_proxy(self, selectable, name=None, **kw): + e = self.element._make_proxy(selectable, + name=name if name else self.name) e.proxies.append(self) return e @@ -4102,18 +4120,6 @@ class ColumnClause(_Immutable, ColumnElement): else: return name - def label(self, name): - # currently, anonymous labels don't occur for - # ColumnClause. The use at the moment - # is that they do not generate nicely for - # is_literal clauses. We would like to change - # this so that label(None) acts as would be expected. - # See [ticket:2168]. - if name is None: - return self - else: - return super(ColumnClause, self).label(name) - def _bind_param(self, operator, obj): return _BindParamClause(self.name, obj, @@ -4121,12 +4127,12 @@ class ColumnClause(_Immutable, ColumnElement): _compared_to_type=self.type, unique=True) - def _make_proxy(self, selectable, name=None, attach=True): + def _make_proxy(self, selectable, name=None, attach=True, **kw): # propagate the "is_literal" flag only if we are keeping our name, # otherwise its considered to be a label is_literal = self.is_literal and (name is None or name == self.name) c = self._constructor( - _as_truncated(name or self.name), + _as_truncated(name if name else self.name), selectable=selectable, type_=self.type, is_literal=is_literal @@ -4137,7 +4143,7 @@ class ColumnClause(_Immutable, ColumnElement): selectable._is_clone_of.columns[c.name] if attach: - selectable._columns[c.name] = c + selectable._columns[c.key] = c return c class TableClause(_Immutable, FromClause): @@ -4563,8 +4569,9 @@ class _ScalarSelect(_Grouping): def self_group(self, **kwargs): return self - def _make_proxy(self, selectable, name): - return list(self.inner_columns)[0]._make_proxy(selectable, name) + def _make_proxy(self, selectable, name=None, **kw): + return list(self.inner_columns)[0]._make_proxy( + selectable, name=name) class CompoundSelect(_SelectBase): """Forms the basis of ``UNION``, ``UNION ALL``, and other @@ -4628,8 +4635,9 @@ class CompoundSelect(_SelectBase): # ForeignKeys in. this would allow the union() to have all # those fks too. 
- proxy = cols[0]._make_proxy(self, name=self.use_labels - and cols[0]._label or None) + proxy = cols[0]._make_proxy(self, + name=cols[0]._label if self.use_labels else None, + key=cols[0]._key_label if self.use_labels else None) # hand-construct the "proxies" collection to include all # derived columns place a 'weight' annotation corresponding @@ -4802,6 +4810,15 @@ class Select(_SelectBase): toremove = set(itertools.chain(*[f._hide_froms for f in froms])) if toremove: + # if we're maintaining clones of froms, + # add the copies out to the toremove list + if self._from_cloned: + toremove.update( + self._from_cloned[f] for f in + toremove.intersection(self._from_cloned) + ) + # filter out to FROM clauses not in the list, + # using a list to maintain ordering froms = [f for f in froms if f not in toremove] if len(froms) > 1 or self._correlate: @@ -5238,8 +5255,8 @@ class Select(_SelectBase): for c in self.inner_columns: if hasattr(c, '_make_proxy'): c._make_proxy(self, - name=self.use_labels - and c._label or None) + name=c._label if self.use_labels else None, + key=c._key_label if self.use_labels else None) def self_group(self, against=None): """return a 'grouping' construct as per the ClauseElement diff --git a/lib/sqlalchemy/sql/operators.py b/lib/sqlalchemy/sql/operators.py index 89f0aaee1..b86b50db4 100644 --- a/lib/sqlalchemy/sql/operators.py +++ b/lib/sqlalchemy/sql/operators.py @@ -521,6 +521,11 @@ def nullslast_op(a): _commutative = set([eq, ne, add, mul]) +_comparison = set([eq, ne, lt, gt, ge, le]) + +def is_comparison(op): + return op in _comparison + def is_commutative(op): return op in _commutative diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index 8d2b5ecfd..cb8359048 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -62,6 +62,65 @@ def find_join_source(clauses, join_to): else: return None, None + +def visit_binary_product(fn, expr): + """Produce a traversal of the given expression, delivering + column comparisons to the given function. + + The function is of the form:: + + def my_fn(binary, left, right) + + For each binary expression located which has a + comparison operator, the product of "left" and + "right" will be delivered to that function, + in terms of that binary. + + Hence an expression like:: + + and_( + (a + b) == q + func.sum(e + f), + j == r + ) + + would have the traversal:: + + a <eq> q + a <eq> e + a <eq> f + b <eq> q + b <eq> e + b <eq> f + j <eq> r + + That is, every combination of "left" and + "right" that doesn't further contain + a binary comparison is passed as pairs. 
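+    For example, to collect the column pairs from a join
+    condition (a sketch; assumes ``t1`` and ``t2`` are
+    :class:`.Table` objects joined by a single foreign key)::
+
+        pairs = []
+
+        def my_fn(binary, left, right):
+            pairs.append((left, right))
+
+        visit_binary_product(my_fn, t1.join(t2).onclause)
+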
+ + """ + stack = [] + def visit(element): + if isinstance(element, (expression._ScalarSelect)): + # we dont want to dig into correlated subqueries, + # those are just column elements by themselves + yield element + elif element.__visit_name__ == 'binary' and \ + operators.is_comparison(element.operator): + stack.insert(0, element) + for l in visit(element.left): + for r in visit(element.right): + fn(stack[0], l, r) + stack.pop(0) + for elem in element.get_children(): + visit(elem) + else: + if isinstance(element, expression.ColumnClause): + yield element + for elem in element.get_children(): + for e in visit(elem): + yield e + list(visit(expr)) + def find_tables(clause, check_columns=False, include_aliases=False, include_joins=False, include_selects=False, include_crud=False): @@ -225,7 +284,10 @@ def adapt_criterion_to_null(crit, nulls): return visitors.cloned_traverse(crit, {}, {'binary':visit_binary}) -def join_condition(a, b, ignore_nonexistent_tables=False, a_subset=None): + +def join_condition(a, b, ignore_nonexistent_tables=False, + a_subset=None, + consider_as_foreign_keys=None): """create a join condition between two tables or selectables. e.g.:: @@ -261,6 +323,9 @@ def join_condition(a, b, ignore_nonexistent_tables=False, a_subset=None): for fk in sorted( b.foreign_keys, key=lambda fk:fk.parent._creation_order): + if consider_as_foreign_keys is not None and \ + fk.parent not in consider_as_foreign_keys: + continue try: col = fk.get_referent(left) except exc.NoReferenceError, nrte: @@ -276,6 +341,9 @@ def join_condition(a, b, ignore_nonexistent_tables=False, a_subset=None): for fk in sorted( left.foreign_keys, key=lambda fk:fk.parent._creation_order): + if consider_as_foreign_keys is not None and \ + fk.parent not in consider_as_foreign_keys: + continue try: col = fk.get_referent(b) except exc.NoReferenceError, nrte: @@ -298,11 +366,11 @@ def join_condition(a, b, ignore_nonexistent_tables=False, a_subset=None): "subquery using alias()?" else: hint = "" - raise exc.ArgumentError( + raise exc.NoForeignKeysError( "Can't find any foreign key relationships " "between '%s' and '%s'.%s" % (a.description, b.description, hint)) elif len(constraints) > 1: - raise exc.ArgumentError( + raise exc.AmbiguousForeignKeysError( "Can't determine join between '%s' and '%s'; " "tables have more than one foreign key " "constraint relationship between them. " @@ -356,13 +424,22 @@ class Annotated(object): def _annotate(self, values): _values = self._annotations.copy() _values.update(values) + return self._with_annotations(_values) + + def _with_annotations(self, values): clone = self.__class__.__new__(self.__class__) clone.__dict__ = self.__dict__.copy() - clone._annotations = _values + clone._annotations = values return clone - def _deannotate(self): - return self.__element + def _deannotate(self, values=None, clone=True): + if values is None: + return self.__element + else: + _values = self._annotations.copy() + for v in values: + _values.pop(v, None) + return self._with_annotations(_values) def _compiler_dispatch(self, visitor, **kw): return self.__element.__class__._compiler_dispatch(self, visitor, **kw) @@ -410,14 +487,8 @@ def _deep_annotate(element, annotations, exclude=None): Elements within the exclude collection will be cloned but not annotated. """ - cloned = util.column_dict() - def clone(elem): - # check if element is present in the exclude list. - # take into account proxying relationships. 
-        if elem in cloned:
-            return cloned[elem]
-        elif exclude and \
+        if exclude and \
             hasattr(elem, 'proxy_set') and \
             elem.proxy_set.intersection(exclude):
             newelem = elem._clone()
@@ -426,24 +497,32 @@ def _deep_annotate(element, annotations, exclude=None):
         else:
             newelem = elem
         newelem._copy_internals(clone=clone)
-        cloned[elem] = newelem
         return newelem
 
     if element is not None:
         element = clone(element)
     return element
 
-def _deep_deannotate(element):
-    """Deep copy the given element, removing all annotations."""
+def _deep_deannotate(element, values=None):
+    """Deep copy the given element, removing annotations."""
     cloned = util.column_dict()
 
     def clone(elem):
-        if elem not in cloned:
-            newelem = elem._deannotate()
+        # if a values dict is given,
+        # the elem must be cloned each time it appears,
+        # as there may be different annotations in source
+        # elements that are remaining.  if totally
+        # removing all annotations, can assume the same
+        # slate...
+        if values or elem not in cloned:
+            newelem = elem._deannotate(values=values, clone=True)
             newelem._copy_internals(clone=clone)
-            cloned[elem] = newelem
-        return cloned[elem]
+            if not values:
+                cloned[elem] = newelem
+            return newelem
+        else:
+            return cloned[elem]
 
     if element is not None:
         element = clone(element)
@@ -547,6 +626,10 @@ def criterion_as_pairs(expression, consider_as_foreign_keys=None,
                 "'consider_as_foreign_keys' or "
                 "'consider_as_referenced_keys'")
 
+    def col_is(a, b):
+        #return a is b
+        return a.compare(b)
+
     def visit_binary(binary):
         if not any_operator and binary.operator is not operators.eq:
             return
@@ -556,20 +639,20 @@ def criterion_as_pairs(expression, consider_as_foreign_keys=None,
 
         if consider_as_foreign_keys:
             if binary.left in consider_as_foreign_keys and \
-                    (binary.right is binary.left or
+                    (col_is(binary.right, binary.left) or
                     binary.right not in consider_as_foreign_keys):
                 pairs.append((binary.right, binary.left))
             elif binary.right in consider_as_foreign_keys and \
-                    (binary.left is binary.right or
+                    (col_is(binary.left, binary.right) or
                     binary.left not in consider_as_foreign_keys):
                 pairs.append((binary.left, binary.right))
         elif consider_as_referenced_keys:
             if binary.left in consider_as_referenced_keys and \
-                    (binary.right is binary.left or
+                    (col_is(binary.right, binary.left) or
                     binary.right not in consider_as_referenced_keys):
                 pairs.append((binary.left, binary.right))
             elif binary.right in consider_as_referenced_keys and \
-                    (binary.left is binary.right or
+                    (col_is(binary.left, binary.right) or
                     binary.left not in consider_as_referenced_keys):
                 pairs.append((binary.right, binary.left))
         else:
@@ -681,11 +764,22 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor):
         s.c.col1 == table2.c.col1
 
     """
-    def __init__(self, selectable, equivalents=None, include=None, exclude=None, adapt_on_names=False):
+    def __init__(self, selectable, equivalents=None,
+                        include=None, exclude=None,
+                        include_fn=None, exclude_fn=None,
+                        adapt_on_names=False):
         self.__traverse_options__ = {'stop_on':[selectable]}
         self.selectable = selectable
-        self.include = include
-        self.exclude = exclude
+        if include:
+            assert not include_fn
+            self.include_fn = lambda e: e in include
+        else:
+            self.include_fn = include_fn
+        if exclude:
+            assert not exclude_fn
+            self.exclude_fn = lambda e: e in exclude
+        else:
+            self.exclude_fn = exclude_fn
         self.equivalents = util.column_dict(equivalents or {})
         self.adapt_on_names = adapt_on_names
@@ -705,19 +799,17 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor):
             return newcol
 
     def replace(self, col):
-        if isinstance(col, expression.FromClause):
-            if self.selectable.is_derived_from(col):
+        if isinstance(col, expression.FromClause) and \
+            self.selectable.is_derived_from(col):
             return self.selectable
-
-        if not isinstance(col, expression.ColumnElement):
+        elif not isinstance(col, expression.ColumnElement):
             return None
-
-        if self.include and col not in self.include:
+        elif self.include_fn and not self.include_fn(col):
             return None
-        elif self.exclude and col in self.exclude:
+        elif self.exclude_fn and self.exclude_fn(col):
             return None
-
-        return self._corresponding_column(col, True)
+        else:
+            return self._corresponding_column(col, True)
 
 class ColumnAdapter(ClauseAdapter):
     """Extends ClauseAdapter with extra utility functions.
diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py
index 5354fbcbb..8a06982fc 100644
--- a/lib/sqlalchemy/sql/visitors.py
+++ b/lib/sqlalchemy/sql/visitors.py
@@ -242,13 +242,13 @@ def cloned_traverse(obj, opts, visitors):
         if elem in stop_on:
             return elem
         else:
-            if elem not in cloned:
-                cloned[elem] = newelem = elem._clone()
+            if id(elem) not in cloned:
+                cloned[id(elem)] = newelem = elem._clone()
                 newelem._copy_internals(clone=clone)
                 meth = visitors.get(newelem.__visit_name__, None)
                 if meth:
                     meth(newelem)
-            return cloned[elem]
+            return cloned[id(elem)]
 
     if obj is not None:
         obj = clone(obj)
@@ -260,16 +260,16 @@ def replacement_traverse(obj, opts, replace):
     replacement by a given replacement function."""
 
     cloned = util.column_dict()
-    stop_on = util.column_set(opts.get('stop_on', []))
+    stop_on = util.column_set([id(x) for x in opts.get('stop_on', [])])
 
     def clone(elem, **kw):
-        if elem in stop_on or \
+        if id(elem) in stop_on or \
             'no_replacement_traverse' in elem._annotations:
             return elem
         else:
             newelem = replace(elem)
             if newelem is not None:
-                stop_on.add(newelem)
+                stop_on.add(id(newelem))
                 return newelem
             else:
                 if elem not in cloned:
diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py
index 512ac6268..c6e3ee3d6 100644
--- a/lib/sqlalchemy/types.py
+++ b/lib/sqlalchemy/types.py
@@ -17,7 +17,7 @@ __all__ = [ 'TypeEngine', 'TypeDecorator', 'AbstractType', 'UserDefinedType',
             'CLOB', 'BLOB', 'BOOLEAN', 'SMALLINT', 'INTEGER', 'DATE', 'TIME',
             'String', 'Integer', 'SmallInteger', 'BigInteger', 'Numeric',
             'Float', 'DateTime', 'Date', 'Time', 'LargeBinary', 'Binary',
-            'Boolean', 'Unicode', 'MutableType', 'Concatenable',
+            'Boolean', 'Unicode', 'Concatenable',
             'UnicodeText','PickleType', 'Interval', 'Enum' ]
 
 import inspect
@@ -83,28 +83,6 @@ class TypeEngine(AbstractType):
 
         return x == y
 
-    def is_mutable(self):
-        """Return True if the target Python type is 'mutable'.
-
-        This allows systems like the ORM to know if a column value can
-        be considered 'not changed' by comparing the identity of
-        objects alone.  Values such as dicts, lists which
-        are serialized into strings are examples of "mutable"
-        column structures.
-
-        .. note::
-
-           This functionality is now superseded by the
-           ``sqlalchemy.ext.mutable`` extension described in
-           :ref:`mutable_toplevel`.
-
-        When this method is overridden, :meth:`copy_value` should
-        also be supplied.  The :class:`.MutableType` mixin
-        is recommended as a helper.
-
-        """
-        return False
-
     def get_dbapi_type(self, dbapi):
         """Return the corresponding type object from the underlying DB-API, if
         any.
@@ -710,30 +688,6 @@ class TypeDecorator(TypeEngine):
         """
         return self.impl.get_dbapi_type(dbapi)
 
-    def copy_value(self, value):
-        """Given a value, produce a copy of it.
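As an aside on the ``include_fn``/``exclude_fn`` hooks above: they generalize the old ``include``/``exclude`` collections into arbitrary predicates. A hypothetical sketch, not taken from the patch (table, alias, and predicate are invented)::

    from sqlalchemy import MetaData, Table, Column, Integer, select
    from sqlalchemy.sql import util as sql_util

    m = MetaData()
    t1 = Table('t1', m, Column('x', Integer), Column('y', Integer))
    a1 = t1.alias()

    # swap references to t1 for the alias, except columns the
    # predicate rejects - here, anything whose key is 'y'
    adapter = sql_util.ClauseAdapter(
        a1, exclude_fn=lambda col: col.key == 'y')

    stmt = select([t1.c.x]).where(t1.c.x == t1.c.y)
    adapted = adapter.traverse(stmt)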
-
-        By default this calls upon :meth:`.TypeEngine.copy_value`
-        of the underlying "impl".
-
-        :meth:`.copy_value` will return the object
-        itself, assuming "mutability" is not enabled.
-        Only the :class:`.MutableType` mixin provides a copy
-        function that actually produces a new object.
-        The copying function is used by the ORM when
-        "mutable" types are used, to memoize the original
-        version of an object as loaded from the database,
-        which is then compared to the possibly mutated
-        version to check for changes.
-
-        Modern implementations should use the
-        ``sqlalchemy.ext.mutable`` extension described in
-        :ref:`mutable_toplevel` for intercepting in-place
-        changes to values.
-
-        """
-        return self.impl.copy_value(value)
-
     def compare_values(self, x, y):
         """Given two values, compare them for equality.
@@ -749,24 +703,6 @@
         """
         return self.impl.compare_values(x, y)
 
-    def is_mutable(self):
-        """Return True if the target Python type is 'mutable'.
-
-        This allows systems like the ORM to know if a column value can
-        be considered 'not changed' by comparing the identity of
-        objects alone.  Values such as dicts, lists which
-        are serialized into strings are examples of "mutable"
-        column structures.
-
-        .. note::
-
-           This functionality is now superseded by the
-           ``sqlalchemy.ext.mutable`` extension described in
-           :ref:`mutable_toplevel`.
-
-        """
-        return self.impl.is_mutable()
-
     def _adapt_expression(self, op, othertype):
         """
         #todo
@@ -828,82 +764,6 @@ class Variant(TypeDecorator):
             mapping[dialect_name] = type_
         return Variant(self.impl, mapping)
 
-class MutableType(object):
-    """A mixin that marks a :class:`.TypeEngine` as representing
-    a mutable Python object type.  This functionality is used
-    only by the ORM.
-
-    .. note::
-
-       :class:`.MutableType` is superseded as of SQLAlchemy 0.7
-       by the ``sqlalchemy.ext.mutable`` extension described in
-       :ref:`mutable_toplevel`.  This extension provides an event
-       driven approach to in-place mutation detection that does not
-       incur the severe performance penalty of the :class:`.MutableType`
-       approach.
-
-    "mutable" means that changes can occur in place to a value
-    of this type.  Examples include Python lists, dictionaries,
-    and sets, as well as user-defined objects.  The primary
-    need for identification of "mutable" types is by the ORM,
-    which applies special rules to such values in order to guarantee
-    that changes are detected.  These rules may have a significant
-    performance impact, described below.
-
-    A :class:`.MutableType` usually allows a flag called
-    ``mutable=False`` to enable/disable the "mutability" flag,
-    represented on this class by :meth:`is_mutable`.  Examples
-    include :class:`.PickleType` and
-    :class:`~sqlalchemy.dialects.postgresql.base.ARRAY`.  Setting
-    this flag to ``True`` enables mutability-specific behavior
-    by the ORM.
-
-    The :meth:`copy_value` and :meth:`compare_values` functions
-    represent a copy and compare function for values of this
-    type - implementing subclasses should override these
-    appropriately.
-
-    .. warning::
-
-       The usage of mutable types has significant performance
-       implications when using the ORM.  In order to detect changes, the
-       ORM must create a copy of the value when it is first
-       accessed, so that changes to the current value can be compared
-       against the "clean" database-loaded value.  Additionally, when the
-       ORM checks to see if any data requires flushing, it must scan
-       through all instances in the session which are known to have
-       "mutable" attributes and compare the current value of each
-       one to its "clean"
-       value.  So for example, if the Session contains 6000 objects (a
-       fairly large amount) and autoflush is enabled, every individual
-       execution of :class:`.Query` will require a full scan of that subset of
-       the 6000 objects that have mutable attributes, possibly resulting
-       in tens of thousands of additional method calls for every query.
-
-       As of SQLAlchemy 0.7, the ``sqlalchemy.ext.mutable`` extension is
-       provided, which allows an event-driven approach to in-place mutation
-       detection.  This approach should now be favored over the usage of
-       :class:`.MutableType` with ``mutable=True``.
-       ``sqlalchemy.ext.mutable`` is described in :ref:`mutable_toplevel`.
-
-    """
-
-    def is_mutable(self):
-        """Return True if the target Python type is 'mutable'.
-
-        For :class:`.MutableType`, this method is set to
-        return ``True``.
-
-        """
-        return True
-
-    def copy_value(self, value):
-        """Unimplemented."""
-        raise NotImplementedError()
-
-    def compare_values(self, x, y):
-        """Compare *x* == *y*."""
-        return x == y
 
 def to_instance(typeobj, *arg, **kw):
     if typeobj is None:
@@ -1961,7 +1821,7 @@ class Enum(String, SchemaType):
         else:
             return super(Enum, self).adapt(impltype, **kw)
 
-class PickleType(MutableType, TypeDecorator):
+class PickleType(TypeDecorator):
     """Holds Python objects, which are serialized using pickle.
 
     PickleType builds upon the Binary type to apply Python's
    ``pickle.dumps()`` to incoming objects, and ``pickle.loads()`` on
     the way out, allowing any pickleable Python object to be stored as
     a serialized binary field.
 
+    To allow ORM change events to propagate for elements associated
+    with :class:`.PickleType`, see :ref:`mutable_toplevel`.
+
     """
 
     impl = LargeBinary
 
     def __init__(self, protocol=pickle.HIGHEST_PROTOCOL,
-                    pickler=None, mutable=False, comparator=None):
+                    pickler=None, comparator=None):
         """
         Construct a PickleType.
 
        :param protocol: defaults to ``pickle.HIGHEST_PROTOCOL``.
 
        :param pickler: defaults to cPickle.pickle or pickle.pickle if
          cPickle is not available.  May be any object with
          pickle-compatible ``dumps`` and ``loads`` methods.
 
-        :param mutable: defaults to False; implements
-          :meth:`AbstractType.is_mutable`.  When ``True``, incoming
-          objects will be compared against copies of themselves
-          using the Python "equals" operator, unless the
-          ``comparator`` argument is present.  See
-          :class:`.MutableType` for details on "mutable" type
-          behavior. (default changed from ``True`` in
-          0.7.0).
-
-          .. note::
-
-             This functionality is now superseded by the
-             ``sqlalchemy.ext.mutable`` extension described in
-             :ref:`mutable_toplevel`.
-
        :param comparator: a 2-arg callable predicate used
          to compare values of this type.  If left as ``None``,
          the Python "equals" operator is used to compare values.
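With ``mutable=True`` gone from ``PickleType``, in-place change tracking moves to the ``sqlalchemy.ext.mutable`` extension referenced above. A minimal sketch along the lines of that extension's documentation (the model and table names are invented)::

    from sqlalchemy import Column, Integer
    from sqlalchemy.types import PickleType
    from sqlalchemy.ext.mutable import Mutable
    from sqlalchemy.ext.declarative import declarative_base

    class MutableDict(Mutable, dict):
        @classmethod
        def coerce(cls, key, value):
            # convert plain dicts to MutableDict as they are assigned
            if not isinstance(value, MutableDict):
                if isinstance(value, dict):
                    return MutableDict(value)
                return Mutable.coerce(key, value)
            return value

        def __setitem__(self, key, value):
            dict.__setitem__(self, key, value)
            self.changed()          # emit a change event to the ORM

        def __delitem__(self, key):
            dict.__delitem__(self, key)
            self.changed()

    Base = declarative_base()

    class Config(Base):
        __tablename__ = 'config'
        id = Column(Integer, primary_key=True)
        data = Column(MutableDict.as_mutable(PickleType))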
@@ -2006,14 +1854,12 @@ class PickleType(TypeDecorator):
         """
         self.protocol = protocol
         self.pickler = pickler or pickle
-        self.mutable = mutable
         self.comparator = comparator
         super(PickleType, self).__init__()
 
     def __reduce__(self):
         return PickleType, (self.protocol,
                             None,
-                            self.mutable,
                             self.comparator)
 
     def bind_processor(self, dialect):
@@ -2048,29 +1894,12 @@ class PickleType(TypeDecorator):
                 return loads(value)
             return process
 
-    def copy_value(self, value):
-        if self.mutable:
-            return self.pickler.loads(
-                        self.pickler.dumps(value, self.protocol))
-        else:
-            return value
-
     def compare_values(self, x, y):
         if self.comparator:
             return self.comparator(x, y)
         else:
             return x == y
 
-    def is_mutable(self):
-        """Return True if the target Python type is 'mutable'.
-
-        When this method is overridden, :meth:`copy_value` should
-        also be supplied.  The :class:`.MutableType` mixin
-        is recommended as a helper.
-
-        """
-        return self.mutable
-
 class Boolean(TypeEngine, SchemaType):
     """A bool datatype.
 
diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py
index 13914aa7d..76c3c829d 100644
--- a/lib/sqlalchemy/util/__init__.py
+++ b/lib/sqlalchemy/util/__init__.py
@@ -27,7 +27,7 @@ from langhelpers import iterate_attributes, class_hierarchy, \
     duck_type_collection, assert_arg_type, symbol, dictlike_iteritems,\
     classproperty, set_creation_order, warn_exception, warn, NoneType,\
     constructor_copy, methods_equivalent, chop_traceback, asint,\
-    generic_repr, counter
+    generic_repr, counter, PluginLoader
 
 from deprecations import warn_deprecated, warn_pending_deprecation, \
     deprecated, pending_deprecation
diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py
index b6c89b11a..9e5b0e4ad 100644
--- a/lib/sqlalchemy/util/langhelpers.py
+++ b/lib/sqlalchemy/util/langhelpers.py
@@ -52,6 +52,45 @@ def decorator(target):
         return update_wrapper(decorated, fn)
     return update_wrapper(decorate, target)
 
+class PluginLoader(object):
+    def __init__(self, group, auto_fn=None):
+        self.group = group
+        self.impls = {}
+        self.auto_fn = auto_fn
+
+    def load(self, name):
+        if name in self.impls:
+            return self.impls[name]()
+
+        if self.auto_fn:
+            loader = self.auto_fn(name)
+            if loader:
+                self.impls[name] = loader
+                return loader()
+
+        try:
+            import pkg_resources
+        except ImportError:
+            pass
+        else:
+            for impl in pkg_resources.iter_entry_points(
+                                self.group, name):
+                self.impls[name] = impl.load
+                return impl.load()
+
+        from sqlalchemy import exc
+        raise exc.ArgumentError(
+                "Can't load plugin: %s:%s" %
+                (self.group, name))
+
+    def register(self, name, modulepath, objname):
+        def load():
+            mod = __import__(modulepath)
+            for token in modulepath.split(".")[1:]:
+                mod = getattr(mod, token)
+            return getattr(mod, objname)
+        self.impls[name] = load
+
 def get_cls_kwargs(cls):
     """Return the full set of inherited kwargs for the given `cls`.
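A hypothetical standalone use of the new ``PluginLoader`` above, showing the lazy-import behavior of ``register``/``load`` (the group, module path, and object name are invented for illustration)::

    from sqlalchemy.util import PluginLoader

    loader = PluginLoader("myapp.plugins")

    # map the name "json" to myapp.codecs.JSONCodec without importing it yet
    loader.register("json", "myapp.codecs", "JSONCodec")

    codec_cls = loader.load("json")   # imports myapp.codecs on first use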
@@ -783,15 +822,21 @@ class classproperty(property):
         return desc.fget(cls)
 
-class _symbol(object):
-    def __init__(self, name, doc=None):
+class _symbol(int):
+    def __new__(self, name, doc=None, canonical=None):
         """Construct a new named symbol."""
         assert isinstance(name, str)
-        self.name = name
+        if canonical is None:
+            canonical = hash(name)
+        v = int.__new__(_symbol, canonical)
+        v.name = name
         if doc:
-            self.__doc__ = doc
+            v.__doc__ = doc
+        return v
+
     def __reduce__(self):
-        return symbol, (self.name,)
+        return symbol, (self.name, "x", int(self))
+
     def __repr__(self):
         return "<symbol '%s>" % self.name
@@ -822,12 +867,12 @@ class symbol(object):
     symbols = {}
     _lock = threading.Lock()
 
-    def __new__(cls, name, doc=None):
+    def __new__(cls, name, doc=None, canonical=None):
         cls._lock.acquire()
         try:
             sym = cls.symbols.get(name)
             if sym is None:
-                cls.symbols[name] = sym = _symbol(name, doc)
+                cls.symbols[name] = sym = _symbol(name, doc, canonical)
             return sym
         finally:
             symbol._lock.release()
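Since ``_symbol`` now subclasses ``int``, a symbol can carry an explicit ``canonical`` integer and participate in integer comparisons while still being interned per name. A brief sketch (the symbol names and canonical values here are invented)::

    from sqlalchemy.util import symbol

    ACTIVE = symbol('ACTIVE', canonical=1)
    INACTIVE = symbol('INACTIVE', canonical=2)

    assert symbol('ACTIVE') is ACTIVE    # one _symbol instance per name
    assert ACTIVE == 1                   # behaves as its canonical integer
    assert ACTIVE != INACTIVE

The three-element ``__reduce__`` tuple likewise lets a pickled symbol round-trip back through ``symbol()`` with its canonical value intact.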
