diff options
| author | Alex Gaynor <alex.gaynor@gmail.com> | 2013-07-08 10:39:54 +1000 |
|---|---|---|
| committer | Alex Gaynor <alex.gaynor@gmail.com> | 2013-07-08 10:39:54 +1000 |
| commit | 03d9566e0df45c72bffa99fe244a92f0279da56f (patch) | |
| tree | b4c9cc9260769d7f79f62c0fe16a2c5a5746e6ed | |
| parent | 0b69a755029f8a8ba6005b9d77be8132a7bc0fb3 (diff) | |
| download | django-03d9566e0df45c72bffa99fe244a92f0279da56f.tar.gz | |
A large number of stylistic cleanups across django/db/
48 files changed, 382 insertions, 194 deletions
diff --git a/django/db/__init__.py b/django/db/__init__.py index 72adf60d66..e1af430dc2 100644 --- a/django/db/__init__.py +++ b/django/db/__init__.py @@ -1,20 +1,25 @@ import warnings from django.core import signals -from django.db.utils import (DEFAULT_DB_ALIAS, - DataError, OperationalError, IntegrityError, InternalError, - ProgrammingError, NotSupportedError, DatabaseError, - InterfaceError, Error, - load_backend, ConnectionHandler, ConnectionRouter) +from django.db.utils import (DEFAULT_DB_ALIAS, DataError, OperationalError, + IntegrityError, InternalError, ProgrammingError, NotSupportedError, + DatabaseError, InterfaceError, Error, load_backend, + ConnectionHandler, ConnectionRouter) from django.utils.functional import cached_property -__all__ = ('backend', 'connection', 'connections', 'router', 'DatabaseError', - 'IntegrityError', 'DEFAULT_DB_ALIAS') + +__all__ = [ + 'backend', 'connection', 'connections', 'router', 'DatabaseError', + 'IntegrityError', 'InternalError', 'ProgrammingError', 'DataError', + 'NotSupportedError', 'Error', 'InterfaceError', 'OperationalError', + 'DEFAULT_DB_ALIAS' +] connections = ConnectionHandler() router = ConnectionRouter() + # `connection`, `DatabaseError` and `IntegrityError` are convenient aliases # for backend bits. @@ -70,6 +75,7 @@ class DefaultBackendProxy(object): backend = DefaultBackendProxy() + def close_connection(**kwargs): warnings.warn( "close_connection is superseded by close_old_connections.", @@ -83,12 +89,14 @@ def close_connection(**kwargs): transaction.abort(conn) connections[conn].close() + # Register an event to reset saved queries when a Django request is started. def reset_queries(**kwargs): for conn in connections.all(): conn.queries = [] signals.request_started.connect(reset_queries) + # Register an event to reset transaction state and close connections past # their lifetime. NB: abort() doesn't do anything outside of a transaction. def close_old_connections(**kwargs): diff --git a/django/db/backends/__init__.py b/django/db/backends/__init__.py index 28a4a48a9f..a553351c5c 100644 --- a/django/db/backends/__init__.py +++ b/django/db/backends/__init__.py @@ -1167,6 +1167,7 @@ FieldInfo = namedtuple('FieldInfo', 'name type_code display_size internal_size precision scale null_ok' ) + class BaseDatabaseIntrospection(object): """ This class encapsulates all backend-specific introspection utilities diff --git a/django/db/backends/creation.py b/django/db/backends/creation.py index 6f66cfb7ca..50e45c0e75 100644 --- a/django/db/backends/creation.py +++ b/django/db/backends/creation.py @@ -251,12 +251,13 @@ class BaseDatabaseCreation(object): r_col = model._meta.get_field(f.rel.field_name).column r_name = '%s_refs_%s_%s' % ( col, r_col, self._digest(table, r_table)) - output.append('%s %s %s %s;' % \ - (style.SQL_KEYWORD('ALTER TABLE'), + output.append('%s %s %s %s;' % ( + style.SQL_KEYWORD('ALTER TABLE'), style.SQL_TABLE(qn(table)), style.SQL_KEYWORD(self.connection.ops.drop_foreignkey_sql()), style.SQL_FIELD(qn(truncate_name( - r_name, self.connection.ops.max_name_length()))))) + r_name, self.connection.ops.max_name_length()))) + )) del references_to_delete[model] return output diff --git a/django/db/backends/dummy/base.py b/django/db/backends/dummy/base.py index 9a220ffd8b..3e6d3e4c5a 100644 --- a/django/db/backends/dummy/base.py +++ b/django/db/backends/dummy/base.py @@ -8,33 +8,43 @@ ImproperlyConfigured. """ from django.core.exceptions import ImproperlyConfigured -from django.db.backends import * +from django.db.backends import (BaseDatabaseOperations, BaseDatabaseClient, + BaseDatabaseIntrospection, BaseDatabaseWrapper, BaseDatabaseFeatures, + BaseDatabaseValidation) from django.db.backends.creation import BaseDatabaseCreation + def complain(*args, **kwargs): raise ImproperlyConfigured("settings.DATABASES is improperly configured. " "Please supply the ENGINE value. Check " "settings documentation for more details.") + def ignore(*args, **kwargs): pass + class DatabaseError(Exception): pass + class IntegrityError(DatabaseError): pass + class DatabaseOperations(BaseDatabaseOperations): quote_name = complain + class DatabaseClient(BaseDatabaseClient): runshell = complain + class DatabaseCreation(BaseDatabaseCreation): create_test_db = ignore destroy_test_db = ignore + class DatabaseIntrospection(BaseDatabaseIntrospection): get_table_list = complain get_table_description = complain @@ -42,6 +52,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): get_indexes = complain get_key_columns = complain + class DatabaseWrapper(BaseDatabaseWrapper): operators = {} # Override the base class implementations with null diff --git a/django/db/backends/mysql/base.py b/django/db/backends/mysql/base.py index d10be94f43..2db746f40f 100644 --- a/django/db/backends/mysql/base.py +++ b/django/db/backends/mysql/base.py @@ -36,8 +36,9 @@ except ImportError: pytz = None from django.conf import settings -from django.db import utils -from django.db.backends import * +from django.db import (utils, BaseDatabaseFeatures, BaseDatabaseOperations, + BaseDatabaseWrapper) +from django.db.backends import util from django.db.backends.mysql.client import DatabaseClient from django.db.backends.mysql.creation import DatabaseCreation from django.db.backends.mysql.introspection import DatabaseIntrospection @@ -57,6 +58,7 @@ IntegrityError = Database.IntegrityError # It's impossible to import datetime_or_None directly from MySQLdb.times parse_datetime = conversions[FIELD_TYPE.DATETIME] + def parse_datetime_with_timezone_support(value): dt = parse_datetime(value) # Confirm that dt is naive before overwriting its tzinfo. @@ -64,6 +66,7 @@ def parse_datetime_with_timezone_support(value): dt = dt.replace(tzinfo=timezone.utc) return dt + def adapt_datetime_with_timezone_support(value, conv): # Equivalent to DateTimeField.get_db_prep_value. Used only by raw SQL. if settings.USE_TZ: @@ -98,6 +101,7 @@ django_conversions.update({ # http://dev.mysql.com/doc/refman/5.0/en/news.html . server_version_re = re.compile(r'(\d{1,2})\.(\d{1,2})\.(\d{1,2})') + # MySQLdb-1.2.1 and newer automatically makes use of SHOW WARNINGS on # MySQL-4.1 and newer, so the MysqlDebugWrapper is unnecessary. Since the # point is to raise Warnings as exceptions, this can be done with the Python @@ -148,6 +152,7 @@ class CursorWrapper(object): def __iter__(self): return iter(self.cursor) + class DatabaseFeatures(BaseDatabaseFeatures): empty_fetchmany_value = () update_can_self_select = False @@ -204,6 +209,7 @@ class DatabaseFeatures(BaseDatabaseFeatures): cursor.execute("SELECT 1 FROM mysql.time_zone LIMIT 1") return cursor.fetchone() is not None + class DatabaseOperations(BaseDatabaseOperations): compiler_module = "django.db.backends.mysql.compiler" @@ -319,7 +325,7 @@ class DatabaseOperations(BaseDatabaseOperations): # Truncate already resets the AUTO_INCREMENT field from # MySQL version 5.0.13 onwards. Refs #16961. if self.connection.mysql_version < (5, 0, 13): - return ["%s %s %s %s %s;" % \ + return ["%s %s %s %s %s;" % (style.SQL_KEYWORD('ALTER'), style.SQL_KEYWORD('TABLE'), style.SQL_TABLE(self.quote_name(sequence['table'])), @@ -373,6 +379,7 @@ class DatabaseOperations(BaseDatabaseOperations): items_sql = "(%s)" % ", ".join(["%s"] * len(fields)) return "VALUES " + ", ".join([items_sql] * num_values) + class DatabaseWrapper(BaseDatabaseWrapper): vendor = 'mysql' operators = { diff --git a/django/db/backends/mysql/client.py b/django/db/backends/mysql/client.py index 1cf8ceef9c..8364c7b6e6 100644 --- a/django/db/backends/mysql/client.py +++ b/django/db/backends/mysql/client.py @@ -3,6 +3,7 @@ import sys from django.db.backends import BaseDatabaseClient + class DatabaseClient(BaseDatabaseClient): executable_name = 'mysql' @@ -37,4 +38,3 @@ class DatabaseClient(BaseDatabaseClient): sys.exit(os.system(" ".join(args))) else: os.execvp(self.executable_name, args) - diff --git a/django/db/backends/mysql/compiler.py b/django/db/backends/mysql/compiler.py index 4e033e3d93..b7d1d7b98d 100644 --- a/django/db/backends/mysql/compiler.py +++ b/django/db/backends/mysql/compiler.py @@ -22,20 +22,26 @@ class SQLCompiler(compiler.SQLCompiler): sql, params = self.as_sql() return '(%s) IN (%s)' % (', '.join(['%s.%s' % (qn(alias), qn2(column)) for column in columns]), sql), params + class SQLInsertCompiler(compiler.SQLInsertCompiler, SQLCompiler): pass + class SQLDeleteCompiler(compiler.SQLDeleteCompiler, SQLCompiler): pass + class SQLUpdateCompiler(compiler.SQLUpdateCompiler, SQLCompiler): pass + class SQLAggregateCompiler(compiler.SQLAggregateCompiler, SQLCompiler): pass + class SQLDateCompiler(compiler.SQLDateCompiler, SQLCompiler): pass + class SQLDateTimeCompiler(compiler.SQLDateTimeCompiler, SQLCompiler): pass diff --git a/django/db/backends/mysql/creation.py b/django/db/backends/mysql/creation.py index 3a57c29479..b59b225f4f 100644 --- a/django/db/backends/mysql/creation.py +++ b/django/db/backends/mysql/creation.py @@ -1,34 +1,35 @@ from django.db.backends.creation import BaseDatabaseCreation + class DatabaseCreation(BaseDatabaseCreation): # This dictionary maps Field objects to their associated MySQL column # types, as strings. Column-type strings can contain format strings; they'll # be interpolated against the values of Field.__dict__ before being output. # If a column type is set to None, it won't be included in the output. data_types = { - 'AutoField': 'integer AUTO_INCREMENT', - 'BinaryField': 'longblob', - 'BooleanField': 'bool', - 'CharField': 'varchar(%(max_length)s)', + 'AutoField': 'integer AUTO_INCREMENT', + 'BinaryField': 'longblob', + 'BooleanField': 'bool', + 'CharField': 'varchar(%(max_length)s)', 'CommaSeparatedIntegerField': 'varchar(%(max_length)s)', - 'DateField': 'date', - 'DateTimeField': 'datetime', - 'DecimalField': 'numeric(%(max_digits)s, %(decimal_places)s)', - 'FileField': 'varchar(%(max_length)s)', - 'FilePathField': 'varchar(%(max_length)s)', - 'FloatField': 'double precision', - 'IntegerField': 'integer', - 'BigIntegerField': 'bigint', - 'IPAddressField': 'char(15)', + 'DateField': 'date', + 'DateTimeField': 'datetime', + 'DecimalField': 'numeric(%(max_digits)s, %(decimal_places)s)', + 'FileField': 'varchar(%(max_length)s)', + 'FilePathField': 'varchar(%(max_length)s)', + 'FloatField': 'double precision', + 'IntegerField': 'integer', + 'BigIntegerField': 'bigint', + 'IPAddressField': 'char(15)', 'GenericIPAddressField': 'char(39)', - 'NullBooleanField': 'bool', - 'OneToOneField': 'integer', + 'NullBooleanField': 'bool', + 'OneToOneField': 'integer', 'PositiveIntegerField': 'integer UNSIGNED', 'PositiveSmallIntegerField': 'smallint UNSIGNED', - 'SlugField': 'varchar(%(max_length)s)', + 'SlugField': 'varchar(%(max_length)s)', 'SmallIntegerField': 'smallint', - 'TextField': 'longtext', - 'TimeField': 'time', + 'TextField': 'longtext', + 'TimeField': 'time', } def sql_table_creation_suffix(self): diff --git a/django/db/backends/mysql/introspection.py b/django/db/backends/mysql/introspection.py index 548877e8e6..ec9f3e99f8 100644 --- a/django/db/backends/mysql/introspection.py +++ b/django/db/backends/mysql/introspection.py @@ -7,6 +7,7 @@ from django.utils.encoding import force_text foreign_key_re = re.compile(r"\sCONSTRAINT `[^`]*` FOREIGN KEY \(`([^`]*)`\) REFERENCES `([^`]*)` \(`([^`]*)`\)") + class DatabaseIntrospection(BaseDatabaseIntrospection): data_types_reverse = { FIELD_TYPE.BLOB: 'TextField', @@ -116,4 +117,3 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): continue indexes[row[4]] = {'primary_key': (row[2] == 'PRIMARY'), 'unique': not bool(row[1])} return indexes - diff --git a/django/db/backends/mysql/validation.py b/django/db/backends/mysql/validation.py index 2ce957cce7..17b7cde756 100644 --- a/django/db/backends/mysql/validation.py +++ b/django/db/backends/mysql/validation.py @@ -1,5 +1,6 @@ from django.db.backends import BaseDatabaseValidation + class DatabaseValidation(BaseDatabaseValidation): def validate_field(self, errors, opts, f): """ diff --git a/django/db/backends/oracle/base.py b/django/db/backends/oracle/base.py index 5e2b763f52..04e70cfdb8 100644 --- a/django/db/backends/oracle/base.py +++ b/django/db/backends/oracle/base.py @@ -7,11 +7,12 @@ from __future__ import unicode_literals import decimal import re +import platform import sys import warnings + def _setup_environment(environ): - import platform # Cygwin requires some special voodoo to set the environment variables # properly so that Oracle will see them. if platform.system().upper().startswith('CYGWIN'): @@ -90,6 +91,7 @@ class DatabaseFeatures(BaseDatabaseFeatures): supports_tablespaces = True supports_sequence_reset = False + class DatabaseOperations(BaseDatabaseOperations): compiler_module = "django.db.backends.oracle.compiler" @@ -308,7 +310,7 @@ WHEN (new.%(col_name)s IS NULL) # Oracle puts the query text into a (query % args) construct, so % signs # in names need to be escaped. The '%%' will be collapsed back to '%' at # that stage so we aren't really making the name longer here. - name = name.replace('%','%%') + name = name.replace('%', '%%') return name.upper() def random_function_sql(self): @@ -758,7 +760,7 @@ class FormatStylePlaceholderCursor(object): def _format_params(self, params): try: - return dict((k,OracleParam(v, self, True)) for k,v in params.items()) + return dict((k, OracleParam(v, self, True)) for k, v in params.items()) except AttributeError: return tuple([OracleParam(p, self, True) for p in params]) @@ -778,12 +780,12 @@ class FormatStylePlaceholderCursor(object): for i, value in enumerate(params): if value.input_size: sizes[i] = value.input_size - self.setinputsizes(*sizes) + self.setinputsizes(*sizes) def _param_generator(self, params): # Try dict handling; if that fails, treat as sequence if hasattr(params, 'items'): - return dict((k, v.force_bytes) for k,v in params.items()) + return dict((k, v.force_bytes) for k, v in params.items()) else: return [p.force_bytes for p in params] @@ -799,14 +801,14 @@ class FormatStylePlaceholderCursor(object): query = convert_unicode(query, self.charset) elif hasattr(params, 'keys'): # Handle params as dict - args = dict((k, ":%s"%k) for k in params.keys()) + args = dict((k, ":%s" % k) for k in params.keys()) query = convert_unicode(query % args, self.charset) else: # Handle params as sequence args = [(':arg%d' % i) for i in range(len(params))] query = convert_unicode(query % tuple(args), self.charset) return query, self._format_params(params) - + def execute(self, query, params=None): query, params = self._fix_for_params(query, params) self._guess_input_sizes([params]) @@ -825,9 +827,9 @@ class FormatStylePlaceholderCursor(object): # uniform treatment for sequences and iterables params_iter = iter(params) query, firstparams = self._fix_for_params(query, next(params_iter)) - # we build a list of formatted params; as we're going to traverse it + # we build a list of formatted params; as we're going to traverse it # more than once, we can't make it lazy by using a generator - formatted = [firstparams]+[self._format_params(p) for p in params_iter] + formatted = [firstparams] + [self._format_params(p) for p in params_iter] self._guess_input_sizes(formatted) try: return self.cursor.executemany(query, diff --git a/django/db/backends/oracle/client.py b/django/db/backends/oracle/client.py index ccc64ebffc..ac6b79041f 100644 --- a/django/db/backends/oracle/client.py +++ b/django/db/backends/oracle/client.py @@ -3,6 +3,7 @@ import sys from django.db.backends import BaseDatabaseClient + class DatabaseClient(BaseDatabaseClient): executable_name = 'sqlplus' @@ -13,4 +14,3 @@ class DatabaseClient(BaseDatabaseClient): sys.exit(os.system(" ".join(args))) else: os.execvp(self.executable_name, args) - diff --git a/django/db/backends/oracle/compiler.py b/django/db/backends/oracle/compiler.py index cbee27951c..d2d4cd3ac9 100644 --- a/django/db/backends/oracle/compiler.py +++ b/django/db/backends/oracle/compiler.py @@ -60,17 +60,22 @@ class SQLCompiler(compiler.SQLCompiler): class SQLInsertCompiler(compiler.SQLInsertCompiler, SQLCompiler): pass + class SQLDeleteCompiler(compiler.SQLDeleteCompiler, SQLCompiler): pass + class SQLUpdateCompiler(compiler.SQLUpdateCompiler, SQLCompiler): pass + class SQLAggregateCompiler(compiler.SQLAggregateCompiler, SQLCompiler): pass + class SQLDateCompiler(compiler.SQLDateCompiler, SQLCompiler): pass + class SQLDateTimeCompiler(compiler.SQLDateTimeCompiler, SQLCompiler): pass diff --git a/django/db/backends/oracle/creation.py b/django/db/backends/oracle/creation.py index c7808e4849..55f6ee4d7e 100644 --- a/django/db/backends/oracle/creation.py +++ b/django/db/backends/oracle/creation.py @@ -5,9 +5,11 @@ from django.conf import settings from django.db.backends.creation import BaseDatabaseCreation from django.utils.six.moves import input + TEST_DATABASE_PREFIX = 'test_' PASSWORD = 'Im_a_lumberjack' + class DatabaseCreation(BaseDatabaseCreation): # This dictionary maps Field objects to their associated Oracle column # types, as strings. Column-type strings can contain format strings; they'll @@ -18,30 +20,30 @@ class DatabaseCreation(BaseDatabaseCreation): # output (the "qn_" prefix is stripped before the lookup is performed. data_types = { - 'AutoField': 'NUMBER(11)', - 'BinaryField': 'BLOB', - 'BooleanField': 'NUMBER(1) CHECK (%(qn_column)s IN (0,1))', - 'CharField': 'NVARCHAR2(%(max_length)s)', - 'CommaSeparatedIntegerField': 'VARCHAR2(%(max_length)s)', - 'DateField': 'DATE', - 'DateTimeField': 'TIMESTAMP', - 'DecimalField': 'NUMBER(%(max_digits)s, %(decimal_places)s)', - 'FileField': 'NVARCHAR2(%(max_length)s)', - 'FilePathField': 'NVARCHAR2(%(max_length)s)', - 'FloatField': 'DOUBLE PRECISION', - 'IntegerField': 'NUMBER(11)', - 'BigIntegerField': 'NUMBER(19)', - 'IPAddressField': 'VARCHAR2(15)', - 'GenericIPAddressField': 'VARCHAR2(39)', - 'NullBooleanField': 'NUMBER(1) CHECK ((%(qn_column)s IN (0,1)) OR (%(qn_column)s IS NULL))', - 'OneToOneField': 'NUMBER(11)', - 'PositiveIntegerField': 'NUMBER(11) CHECK (%(qn_column)s >= 0)', - 'PositiveSmallIntegerField': 'NUMBER(11) CHECK (%(qn_column)s >= 0)', - 'SlugField': 'NVARCHAR2(%(max_length)s)', - 'SmallIntegerField': 'NUMBER(11)', - 'TextField': 'NCLOB', - 'TimeField': 'TIMESTAMP', - 'URLField': 'VARCHAR2(%(max_length)s)', + 'AutoField': 'NUMBER(11)', + 'BinaryField': 'BLOB', + 'BooleanField': 'NUMBER(1) CHECK (%(qn_column)s IN (0,1))', + 'CharField': 'NVARCHAR2(%(max_length)s)', + 'CommaSeparatedIntegerField': 'VARCHAR2(%(max_length)s)', + 'DateField': 'DATE', + 'DateTimeField': 'TIMESTAMP', + 'DecimalField': 'NUMBER(%(max_digits)s, %(decimal_places)s)', + 'FileField': 'NVARCHAR2(%(max_length)s)', + 'FilePathField': 'NVARCHAR2(%(max_length)s)', + 'FloatField': 'DOUBLE PRECISION', + 'IntegerField': 'NUMBER(11)', + 'BigIntegerField': 'NUMBER(19)', + 'IPAddressField': 'VARCHAR2(15)', + 'GenericIPAddressField': 'VARCHAR2(39)', + 'NullBooleanField': 'NUMBER(1) CHECK ((%(qn_column)s IN (0,1)) OR (%(qn_column)s IS NULL))', + 'OneToOneField': 'NUMBER(11)', + 'PositiveIntegerField': 'NUMBER(11) CHECK (%(qn_column)s >= 0)', + 'PositiveSmallIntegerField': 'NUMBER(11) CHECK (%(qn_column)s >= 0)', + 'SlugField': 'NVARCHAR2(%(max_length)s)', + 'SmallIntegerField': 'NUMBER(11)', + 'TextField': 'NCLOB', + 'TimeField': 'TIMESTAMP', + 'URLField': 'VARCHAR2(%(max_length)s)', } def __init__(self, connection): @@ -183,7 +185,7 @@ class DatabaseCreation(BaseDatabaseCreation): statements = [ 'DROP TABLESPACE %(tblspace)s INCLUDING CONTENTS AND DATAFILES CASCADE CONSTRAINTS', 'DROP TABLESPACE %(tblspace_temp)s INCLUDING CONTENTS AND DATAFILES CASCADE CONSTRAINTS', - ] + ] self._execute_statements(cursor, statements, parameters, verbosity) def _destroy_test_user(self, cursor, parameters, verbosity): diff --git a/django/db/backends/oracle/introspection.py b/django/db/backends/oracle/introspection.py index 361308a62c..a2fad92509 100644 --- a/django/db/backends/oracle/introspection.py +++ b/django/db/backends/oracle/introspection.py @@ -1,10 +1,13 @@ +import re + +import cx_Oracle + from django.db.backends import BaseDatabaseIntrospection, FieldInfo from django.utils.encoding import force_text -import cx_Oracle -import re foreign_key_re = re.compile(r"\sCONSTRAINT `[^`]*` FOREIGN KEY \(`([^`]*)`\) REFERENCES `([^`]*)` \(`([^`]*)`\)") + class DatabaseIntrospection(BaseDatabaseIntrospection): # Maps type objects to Django Field types. data_types_reverse = { @@ -95,11 +98,11 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): SELECT ccol.column_name, rcol.table_name AS referenced_table, rcol.column_name AS referenced_column FROM user_constraints c JOIN user_cons_columns ccol - ON ccol.constraint_name = c.constraint_name + ON ccol.constraint_name = c.constraint_name JOIN user_cons_columns rcol - ON rcol.constraint_name = c.r_constraint_name - WHERE c.table_name = %s AND c.constraint_type = 'R'""" , [table_name.upper()]) - return [tuple(cell.lower() for cell in row) + ON rcol.constraint_name = c.r_constraint_name + WHERE c.table_name = %s AND c.constraint_type = 'R'""", [table_name.upper()]) + return [tuple(cell.lower() for cell in row) for row in cursor.fetchall()] def get_indexes(self, cursor, table_name): diff --git a/django/db/backends/postgresql_psycopg2/base.py b/django/db/backends/postgresql_psycopg2/base.py index 6ed2cfcc7c..e676065578 100644 --- a/django/db/backends/postgresql_psycopg2/base.py +++ b/django/db/backends/postgresql_psycopg2/base.py @@ -6,7 +6,9 @@ Requires psycopg 2: http://initd.org/projects/psycopg2 import logging import sys -from django.db.backends import * +from django.conf import settings +from django.db.backends import (BaseDatabaseFeatures, BaseDatabaseWrapper, + BaseDatabaseValidation) from django.db.backends.postgresql_psycopg2.operations import DatabaseOperations from django.db.backends.postgresql_psycopg2.client import DatabaseClient from django.db.backends.postgresql_psycopg2.creation import DatabaseCreation @@ -33,11 +35,13 @@ psycopg2.extensions.register_adapter(SafeText, psycopg2.extensions.QuotedString) logger = logging.getLogger('django.db.backends') + def utc_tzinfo_factory(offset): if offset != 0: raise AssertionError("database connection isn't set to UTC") return utc + class DatabaseFeatures(BaseDatabaseFeatures): needs_datetime_string_cast = False can_return_id_from_insert = True @@ -52,6 +56,7 @@ class DatabaseFeatures(BaseDatabaseFeatures): supports_transactions = True can_distinct_on_fields = True + class DatabaseWrapper(BaseDatabaseWrapper): vendor = 'postgresql' operators = { @@ -132,7 +137,8 @@ class DatabaseWrapper(BaseDatabaseWrapper): # Set the time zone in autocommit mode (see #17062) self.set_autocommit(True) self.connection.cursor().execute( - self.ops.set_time_zone_sql(), [tz]) + self.ops.set_time_zone_sql(), [tz] + ) self.connection.set_isolation_level(self.isolation_level) def create_cursor(self): diff --git a/django/db/backends/postgresql_psycopg2/client.py b/django/db/backends/postgresql_psycopg2/client.py index a5c02969ea..23ac9f2975 100644 --- a/django/db/backends/postgresql_psycopg2/client.py +++ b/django/db/backends/postgresql_psycopg2/client.py @@ -3,6 +3,7 @@ import sys from django.db.backends import BaseDatabaseClient + class DatabaseClient(BaseDatabaseClient): executable_name = 'psql' @@ -20,4 +21,3 @@ class DatabaseClient(BaseDatabaseClient): sys.exit(os.system(" ".join(args))) else: os.execvp(self.executable_name, args) - diff --git a/django/db/backends/postgresql_psycopg2/creation.py b/django/db/backends/postgresql_psycopg2/creation.py index d4260e05c4..cbf901555d 100644 --- a/django/db/backends/postgresql_psycopg2/creation.py +++ b/django/db/backends/postgresql_psycopg2/creation.py @@ -8,29 +8,29 @@ class DatabaseCreation(BaseDatabaseCreation): # be interpolated against the values of Field.__dict__ before being output. # If a column type is set to None, it won't be included in the output. data_types = { - 'AutoField': 'serial', - 'BinaryField': 'bytea', - 'BooleanField': 'boolean', - 'CharField': 'varchar(%(max_length)s)', + 'AutoField': 'serial', + 'BinaryField': 'bytea', + 'BooleanField': 'boolean', + 'CharField': 'varchar(%(max_length)s)', 'CommaSeparatedIntegerField': 'varchar(%(max_length)s)', - 'DateField': 'date', - 'DateTimeField': 'timestamp with time zone', - 'DecimalField': 'numeric(%(max_digits)s, %(decimal_places)s)', - 'FileField': 'varchar(%(max_length)s)', - 'FilePathField': 'varchar(%(max_length)s)', - 'FloatField': 'double precision', - 'IntegerField': 'integer', - 'BigIntegerField': 'bigint', - 'IPAddressField': 'inet', + 'DateField': 'date', + 'DateTimeField': 'timestamp with time zone', + 'DecimalField': 'numeric(%(max_digits)s, %(decimal_places)s)', + 'FileField': 'varchar(%(max_length)s)', + 'FilePathField': 'varchar(%(max_length)s)', + 'FloatField': 'double precision', + 'IntegerField': 'integer', + 'BigIntegerField': 'bigint', + 'IPAddressField': 'inet', 'GenericIPAddressField': 'inet', - 'NullBooleanField': 'boolean', - 'OneToOneField': 'integer', + 'NullBooleanField': 'boolean', + 'OneToOneField': 'integer', 'PositiveIntegerField': 'integer CHECK ("%(column)s" >= 0)', 'PositiveSmallIntegerField': 'smallint CHECK ("%(column)s" >= 0)', - 'SlugField': 'varchar(%(max_length)s)', + 'SlugField': 'varchar(%(max_length)s)', 'SmallIntegerField': 'smallint', - 'TextField': 'text', - 'TimeField': 'time', + 'TextField': 'text', + 'TimeField': 'time', } def sql_table_creation_suffix(self): @@ -54,7 +54,7 @@ class DatabaseCreation(BaseDatabaseCreation): def get_index_sql(index_name, opclass=''): return (style.SQL_KEYWORD('CREATE INDEX') + ' ' + - style.SQL_TABLE(qn(truncate_name(index_name,self.connection.ops.max_name_length()))) + ' ' + + style.SQL_TABLE(qn(truncate_name(index_name, self.connection.ops.max_name_length()))) + ' ' + style.SQL_KEYWORD('ON') + ' ' + style.SQL_TABLE(qn(db_table)) + ' ' + "(%s%s)" % (style.SQL_FIELD(qn(f.column)), opclass) + diff --git a/django/db/backends/postgresql_psycopg2/introspection.py b/django/db/backends/postgresql_psycopg2/introspection.py index c334b9d6d0..0ebd3c1ed5 100644 --- a/django/db/backends/postgresql_psycopg2/introspection.py +++ b/django/db/backends/postgresql_psycopg2/introspection.py @@ -25,7 +25,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): 1266: 'TimeField', 1700: 'DecimalField', } - + def get_table_list(self, cursor): "Returns a list of table names in the current database." cursor.execute(""" @@ -47,7 +47,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): WHERE table_name = %s""", [table_name]) null_map = dict(cursor.fetchall()) cursor.execute("SELECT * FROM %s LIMIT 1" % self.connection.ops.quote_name(table_name)) - return [FieldInfo(*((force_text(line[0]),) + line[1:6] + (null_map[force_text(line[0])]=='YES',))) + return [FieldInfo(*((force_text(line[0]),) + line[1:6] + (null_map[force_text(line[0])] == 'YES',))) for line in cursor.description] def get_relations(self, cursor, table_name): @@ -81,7 +81,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): ON ccu.constraint_catalog = tc.constraint_catalog AND ccu.constraint_schema = tc.constraint_schema AND ccu.constraint_name = tc.constraint_name - WHERE kcu.table_name = %s AND tc.constraint_type = 'FOREIGN KEY'""" , [table_name]) + WHERE kcu.table_name = %s AND tc.constraint_type = 'FOREIGN KEY'""", [table_name]) key_columns.extend(cursor.fetchall()) return key_columns diff --git a/django/db/backends/postgresql_psycopg2/operations.py b/django/db/backends/postgresql_psycopg2/operations.py index c5aab84693..cc78ffe449 100644 --- a/django/db/backends/postgresql_psycopg2/operations.py +++ b/django/db/backends/postgresql_psycopg2/operations.py @@ -135,7 +135,7 @@ class DatabaseOperations(BaseDatabaseOperations): # This will be the case if it's an m2m using an autogenerated # intermediate table (see BaseDatabaseIntrospection.sequence_list) column_name = 'id' - sql.append("%s setval(pg_get_serial_sequence('%s','%s'), 1, false);" % \ + sql.append("%s setval(pg_get_serial_sequence('%s','%s'), 1, false);" % (style.SQL_KEYWORD('SELECT'), style.SQL_TABLE(self.quote_name(table_name)), style.SQL_FIELD(column_name)) @@ -161,7 +161,7 @@ class DatabaseOperations(BaseDatabaseOperations): for f in model._meta.local_fields: if isinstance(f, models.AutoField): - output.append("%s setval(pg_get_serial_sequence('%s','%s'), coalesce(max(%s), 1), max(%s) %s null) %s %s;" % \ + output.append("%s setval(pg_get_serial_sequence('%s','%s'), coalesce(max(%s), 1), max(%s) %s null) %s %s;" % (style.SQL_KEYWORD('SELECT'), style.SQL_TABLE(qn(model._meta.db_table)), style.SQL_FIELD(f.column), @@ -173,7 +173,7 @@ class DatabaseOperations(BaseDatabaseOperations): break # Only one AutoField is allowed per model, so don't bother continuing. for f in model._meta.many_to_many: if not f.rel.through: - output.append("%s setval(pg_get_serial_sequence('%s','%s'), coalesce(max(%s), 1), max(%s) %s null) %s %s;" % \ + output.append("%s setval(pg_get_serial_sequence('%s','%s'), coalesce(max(%s), 1), max(%s) %s null) %s %s;" % (style.SQL_KEYWORD('SELECT'), style.SQL_TABLE(qn(f.m2m_db_table())), style.SQL_FIELD('id'), diff --git a/django/db/backends/postgresql_psycopg2/version.py b/django/db/backends/postgresql_psycopg2/version.py index 8ef516704e..dae94f2dac 100644 --- a/django/db/backends/postgresql_psycopg2/version.py +++ b/django/db/backends/postgresql_psycopg2/version.py @@ -19,7 +19,8 @@ def _parse_version(text): try: return int(major) * 10000 + int(major2) * 100 + int(minor) except (ValueError, TypeError): - return int(major) * 10000 + int(major2) * 100 + return int(major) * 10000 + int(major2) * 100 + def get_version(connection): """ diff --git a/django/db/backends/sqlite3/base.py b/django/db/backends/sqlite3/base.py index 4345790b06..dc7db2fceb 100644 --- a/django/db/backends/sqlite3/base.py +++ b/django/db/backends/sqlite3/base.py @@ -11,8 +11,10 @@ import decimal import warnings import re +from django.conf import settings from django.db import utils -from django.db.backends import * +from django.db.backends import (util, BaseDatabaseFeatures, + BaseDatabaseOperations, BaseDatabaseWrapper, BaseDatabaseValidation) from django.db.backends.sqlite3.client import DatabaseClient from django.db.backends.sqlite3.creation import DatabaseCreation from django.db.backends.sqlite3.introspection import DatabaseIntrospection @@ -42,6 +44,7 @@ except ImportError: DatabaseError = Database.DatabaseError IntegrityError = Database.IntegrityError + def parse_datetime_with_timezone_support(value): dt = parse_datetime(value) # Confirm that dt is naive before overwriting its tzinfo. @@ -49,6 +52,7 @@ def parse_datetime_with_timezone_support(value): dt = dt.replace(tzinfo=timezone.utc) return dt + def adapt_datetime_with_timezone_support(value): # Equivalent to DateTimeField.get_db_prep_value. Used only by raw SQL. if settings.USE_TZ: @@ -61,6 +65,7 @@ def adapt_datetime_with_timezone_support(value): value = value.astimezone(timezone.utc).replace(tzinfo=None) return value.isoformat(str(" ")) + def decoder(conv_func): """ The Python sqlite3 interface returns always byte strings. This function converts the received value to a regular string before @@ -81,6 +86,7 @@ Database.register_adapter(decimal.Decimal, util.rev_typecast_decimal) Database.register_adapter(str, lambda s: s.decode('utf-8')) Database.register_adapter(SafeBytes, lambda s: s.decode('utf-8')) + class DatabaseFeatures(BaseDatabaseFeatures): # SQLite cannot handle us only partially reading from a cursor's result set # and then writing the same rows to the database in another cursor. This @@ -124,6 +130,7 @@ class DatabaseFeatures(BaseDatabaseFeatures): def has_zoneinfo_database(self): return pytz is not None + class DatabaseOperations(BaseDatabaseOperations): def bulk_batch_size(self, fields, objs): """ @@ -272,6 +279,7 @@ class DatabaseOperations(BaseDatabaseOperations): res.extend(["UNION ALL SELECT %s" % ", ".join(["%s"] * len(fields))] * (num_values - 1)) return " ".join(res) + class DatabaseWrapper(BaseDatabaseWrapper): vendor = 'sqlite' # SQLite requires LIKE statements to include an ESCAPE clause if the value @@ -426,6 +434,7 @@ class DatabaseWrapper(BaseDatabaseWrapper): FORMAT_QMARK_REGEX = re.compile(r'(?<!%)%s') + class SQLiteCursorWrapper(Database.Cursor): """ Django uses "format" style placeholders, but pysqlite2 uses "qmark" style. @@ -445,6 +454,7 @@ class SQLiteCursorWrapper(Database.Cursor): def convert_query(self, query): return FORMAT_QMARK_REGEX.sub('?', query).replace('%%', '%') + def _sqlite_date_extract(lookup_type, dt): if dt is None: return None @@ -457,6 +467,7 @@ def _sqlite_date_extract(lookup_type, dt): else: return getattr(dt, lookup_type) + def _sqlite_date_trunc(lookup_type, dt): try: dt = util.typecast_timestamp(dt) @@ -469,6 +480,7 @@ def _sqlite_date_trunc(lookup_type, dt): elif lookup_type == 'day': return "%i-%02i-%02i" % (dt.year, dt.month, dt.day) + def _sqlite_datetime_extract(lookup_type, dt, tzname): if dt is None: return None @@ -483,6 +495,7 @@ def _sqlite_datetime_extract(lookup_type, dt, tzname): else: return getattr(dt, lookup_type) + def _sqlite_datetime_trunc(lookup_type, dt, tzname): try: dt = util.typecast_timestamp(dt) @@ -503,6 +516,7 @@ def _sqlite_datetime_trunc(lookup_type, dt, tzname): elif lookup_type == 'second': return "%i-%02i-%02i %02i:%02i:%02i" % (dt.year, dt.month, dt.day, dt.hour, dt.minute, dt.second) + def _sqlite_format_dtdelta(dt, conn, days, secs, usecs): try: dt = util.typecast_timestamp(dt) @@ -517,5 +531,6 @@ def _sqlite_format_dtdelta(dt, conn, days, secs, usecs): # It will be formatted as "%Y-%m-%d" or "%Y-%m-%d %H:%M:%S[.%f]" return str(dt) + def _sqlite_regexp(re_pattern, re_string): return bool(re.search(re_pattern, force_text(re_string))) if re_string is not None else False diff --git a/django/db/backends/sqlite3/client.py b/django/db/backends/sqlite3/client.py index 5b5b7326f2..6a3ad9e76f 100644 --- a/django/db/backends/sqlite3/client.py +++ b/django/db/backends/sqlite3/client.py @@ -3,6 +3,7 @@ import sys from django.db.backends import BaseDatabaseClient + class DatabaseClient(BaseDatabaseClient): executable_name = 'sqlite3' @@ -13,4 +14,3 @@ class DatabaseClient(BaseDatabaseClient): sys.exit(os.system(" ".join(args))) else: os.execvp(self.executable_name, args) - diff --git a/django/db/backends/sqlite3/creation.py b/django/db/backends/sqlite3/creation.py index a9fb273f7a..435cee3436 100644 --- a/django/db/backends/sqlite3/creation.py +++ b/django/db/backends/sqlite3/creation.py @@ -1,36 +1,38 @@ import os import sys + from django.db.backends.creation import BaseDatabaseCreation from django.utils.six.moves import input + class DatabaseCreation(BaseDatabaseCreation): # SQLite doesn't actually support most of these types, but it "does the right # thing" given more verbose field definitions, so leave them as is so that # schema inspection is more useful. data_types = { - 'AutoField': 'integer', - 'BinaryField': 'BLOB', - 'BooleanField': 'bool', - 'CharField': 'varchar(%(max_length)s)', - 'CommaSeparatedIntegerField': 'varchar(%(max_length)s)', - 'DateField': 'date', - 'DateTimeField': 'datetime', - 'DecimalField': 'decimal', - 'FileField': 'varchar(%(max_length)s)', - 'FilePathField': 'varchar(%(max_length)s)', - 'FloatField': 'real', - 'IntegerField': 'integer', - 'BigIntegerField': 'bigint', - 'IPAddressField': 'char(15)', - 'GenericIPAddressField': 'char(39)', - 'NullBooleanField': 'bool', - 'OneToOneField': 'integer', - 'PositiveIntegerField': 'integer unsigned', - 'PositiveSmallIntegerField': 'smallint unsigned', - 'SlugField': 'varchar(%(max_length)s)', - 'SmallIntegerField': 'smallint', - 'TextField': 'text', - 'TimeField': 'time', + 'AutoField': 'integer', + 'BinaryField': 'BLOB', + 'BooleanField': 'bool', + 'CharField': 'varchar(%(max_length)s)', + 'CommaSeparatedIntegerField': 'varchar(%(max_length)s)', + 'DateField': 'date', + 'DateTimeField': 'datetime', + 'DecimalField': 'decimal', + 'FileField': 'varchar(%(max_length)s)', + 'FilePathField': 'varchar(%(max_length)s)', + 'FloatField': 'real', + 'IntegerField': 'integer', + 'BigIntegerField': 'bigint', + 'IPAddressField': 'char(15)', + 'GenericIPAddressField': 'char(39)', + 'NullBooleanField': 'bool', + 'OneToOneField': 'integer', + 'PositiveIntegerField': 'integer unsigned', + 'PositiveSmallIntegerField': 'smallint unsigned', + 'SlugField': 'varchar(%(max_length)s)', + 'SmallIntegerField': 'smallint', + 'TextField': 'text', + 'TimeField': 'time', } def sql_for_pending_references(self, model, style, pending_references): @@ -80,7 +82,6 @@ class DatabaseCreation(BaseDatabaseCreation): SQLite since the databases will be distinct despite having the same TEST_NAME. See http://www.sqlite.org/inmemorydb.html """ - settings_dict = self.connection.settings_dict test_dbname = self._get_test_db_name() sig = [self.connection.settings_dict['NAME']] if test_dbname == ':memory:': diff --git a/django/db/backends/sqlite3/introspection.py b/django/db/backends/sqlite3/introspection.py index 431e112e56..385888ed40 100644 --- a/django/db/backends/sqlite3/introspection.py +++ b/django/db/backends/sqlite3/introspection.py @@ -1,8 +1,11 @@ import re + from django.db.backends import BaseDatabaseIntrospection, FieldInfo + field_size_re = re.compile(r'^\s*(?:var)?char\s*\(\s*(\d+)\s*\)\s*$') + def get_field_size(name): """ Extract the size number from a "varchar(11)" type name """ m = field_size_re.search(name) @@ -46,6 +49,7 @@ class FlexibleFieldLookupDict(object): return ('CharField', {'max_length': size}) raise KeyError + class DatabaseIntrospection(BaseDatabaseIntrospection): data_types_reverse = FlexibleFieldLookupDict() @@ -76,7 +80,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): # Schema for this table cursor.execute("SELECT sql FROM sqlite_master WHERE tbl_name = %s AND type = %s", [table_name, "table"]) results = cursor.fetchone()[0].strip() - results = results[results.index('(')+1:results.rindex(')')] + results = results[results.index('(') + 1:results.rindex(')')] # Walk through and look for references to other tables. SQLite doesn't # really have enforced references, but since it echoes out the SQL used @@ -96,8 +100,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): result = cursor.fetchall()[0] other_table_results = result[0].strip() li, ri = other_table_results.index('('), other_table_results.rindex(')') - other_table_results = other_table_results[li+1:ri] - + other_table_results = other_table_results[li + 1:ri] for other_index, other_desc in enumerate(other_table_results.split(',')): other_desc = other_desc.strip() @@ -121,7 +124,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): # Schema for this table cursor.execute("SELECT sql FROM sqlite_master WHERE tbl_name = %s AND type = %s", [table_name, "table"]) results = cursor.fetchone()[0].strip() - results = results[results.index('(')+1:results.rindex(')')] + results = results[results.index('(') + 1:results.rindex(')')] # Walk through and look for references to other tables. SQLite doesn't # really have enforced references, but since it echoes out the SQL used @@ -166,7 +169,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): # Don't use PRAGMA because that causes issues with some transactions cursor.execute("SELECT sql FROM sqlite_master WHERE tbl_name = %s AND type = %s", [table_name, "table"]) results = cursor.fetchone()[0].strip() - results = results[results.index('(')+1:results.rindex(')')] + results = results[results.index('(') + 1:results.rindex(')')] for field_desc in results.split(','): field_desc = field_desc.strip() m = re.search('"(.*)".*PRIMARY KEY$', field_desc) diff --git a/django/db/backends/util.py b/django/db/backends/util.py index aa2601277a..f8157c50e7 100644 --- a/django/db/backends/util.py +++ b/django/db/backends/util.py @@ -85,20 +85,25 @@ class CursorDebugWrapper(CursorWrapper): def typecast_date(s): return datetime.date(*map(int, s.split('-'))) if s else None # returns None if s is null + def typecast_time(s): # does NOT store time zone information - if not s: return None + if not s: + return None hour, minutes, seconds = s.split(':') if '.' in seconds: # check whether seconds have a fractional part seconds, microseconds = seconds.split('.') else: microseconds = '0' - return datetime.time(int(hour), int(minutes), int(seconds), int(float('.'+microseconds) * 1000000)) + return datetime.time(int(hour), int(minutes), int(seconds), int(float('.' + microseconds) * 1000000)) + def typecast_timestamp(s): # does NOT store time zone information # "2005-07-29 15:48:00.590358-05" # "2005-07-29 09:56:00-05" - if not s: return None - if not ' ' in s: return typecast_date(s) + if not s: + return None + if not ' ' in s: + return typecast_date(s) d, t = s.split() # Extract timezone information, if it exists. Currently we just throw # it away, but in the future we may make use of it. @@ -122,11 +127,13 @@ def typecast_timestamp(s): # does NOT store time zone information int(times[0]), int(times[1]), int(seconds), int((microseconds + '000000')[:6]), tzinfo) + def typecast_decimal(s): if s is None or s == '': return None return decimal.Decimal(s) + ############################################### # Converters from Python to database (string) # ############################################### @@ -136,6 +143,7 @@ def rev_typecast_decimal(d): return None return str(d) + def truncate_name(name, length=None, hash_len=4): """Shortens a string to a repeatable mangled version with the given length. """ @@ -143,7 +151,8 @@ def truncate_name(name, length=None, hash_len=4): return name hsh = hashlib.md5(force_bytes(name)).hexdigest()[:hash_len] - return '%s%s' % (name[:length-hash_len], hsh) + return '%s%s' % (name[:length - hash_len], hsh) + def format_number(value, max_digits, decimal_places): """ diff --git a/django/db/models/__init__.py b/django/db/models/__init__.py index b5dd1a58bc..4d310e480b 100644 --- a/django/db/models/__init__.py +++ b/django/db/models/__init__.py @@ -26,6 +26,7 @@ def permalink(func): (viewname, viewargs, viewkwargs) """ from django.core.urlresolvers import reverse + @wraps(func) def inner(*args, **kwargs): bits = func(*args, **kwargs) diff --git a/django/db/models/aggregates.py b/django/db/models/aggregates.py index b89db1c563..1db3890204 100644 --- a/django/db/models/aggregates.py +++ b/django/db/models/aggregates.py @@ -3,6 +3,7 @@ Classes to represent the definitions of aggregate functions. """ from django.db.models.constants import LOOKUP_SEP + def refs_aggregate(lookup_parts, aggregates): """ A little helper method to check if the lookup_parts contains references @@ -15,6 +16,7 @@ def refs_aggregate(lookup_parts, aggregates): return True return False + class Aggregate(object): """ Default Aggregate definition. @@ -58,23 +60,30 @@ class Aggregate(object): aggregate = klass(col, source=source, is_summary=is_summary, **self.extra) query.aggregates[alias] = aggregate + class Avg(Aggregate): name = 'Avg' + class Count(Aggregate): name = 'Count' + class Max(Aggregate): name = 'Max' + class Min(Aggregate): name = 'Min' + class StdDev(Aggregate): name = 'StdDev' + class Sum(Aggregate): name = 'Sum' + class Variance(Aggregate): name = 'Variance' diff --git a/django/db/models/base.py b/django/db/models/base.py index 1238bfb4ce..fbfda616a1 100644 --- a/django/db/models/base.py +++ b/django/db/models/base.py @@ -226,9 +226,9 @@ class ModelBase(type): # class for field in base._meta.virtual_fields: if base._meta.abstract and field.name in field_names: - raise FieldError('Local field %r in class %r clashes '\ - 'with field of similar name from '\ - 'abstract base class %r' % \ + raise FieldError('Local field %r in class %r clashes ' + 'with field of similar name from ' + 'abstract base class %r' % (field.name, name, base.__name__)) new_class.add_to_class(field.name, copy.deepcopy(field)) @@ -1008,8 +1008,6 @@ def get_absolute_url(opts, func, self, *args, **kwargs): # MISC # ######## -class Empty(object): - pass def simple_class_factory(model, attrs): """ @@ -1017,6 +1015,7 @@ def simple_class_factory(model, attrs): """ return model + def model_unpickle(model_id, attrs, factory): """ Used to unpickle Model subclasses with deferred fields. diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index 6e0f3c434e..b4eea5fe71 100644 --- a/django/db/models/expressions.py +++ b/django/db/models/expressions.py @@ -4,6 +4,7 @@ from django.db.models.aggregates import refs_aggregate from django.db.models.constants import LOOKUP_SEP from django.utils import tree + class ExpressionNode(tree.Node): """ Base class for all query expressions. @@ -128,6 +129,7 @@ class ExpressionNode(tree.Node): "Use .bitand() and .bitor() for bitwise logical operations." ) + class F(ExpressionNode): """ An expression representing the value of the given field. @@ -147,6 +149,7 @@ class F(ExpressionNode): def evaluate(self, evaluator, qn, connection): return evaluator.evaluate_leaf(self, qn, connection) + class DateModifierNode(ExpressionNode): """ Node that implements the following syntax: diff --git a/django/db/models/fields/__init__.py b/django/db/models/fields/__init__.py index a12b150cf6..02d887b7f8 100644 --- a/django/db/models/fields/__init__.py +++ b/django/db/models/fields/__init__.py @@ -25,9 +25,11 @@ from django.utils.encoding import smart_text, force_text, force_bytes from django.utils.ipv6 import clean_ipv6_address from django.utils import six + class Empty(object): pass + class NOT_PROVIDED: pass @@ -35,12 +37,15 @@ class NOT_PROVIDED: # of most "choices" lists. BLANK_CHOICE_DASH = [("", "---------")] + def _load_field(app_label, model_name, field_name): return get_model(app_label, model_name)._meta.get_field_by_name(field_name)[0] + class FieldDoesNotExist(Exception): pass + # A guide to Field parameters: # # * name: The name of the field specifed in the model. @@ -61,6 +66,7 @@ def _empty(of_cls): new.__class__ = of_cls return new + @total_ordering class Field(object): """Base class for all field types""" @@ -444,12 +450,12 @@ class Field(object): if hasattr(value, '_prepare'): return value._prepare() - if lookup_type in ( - 'iexact', 'contains', 'icontains', - 'startswith', 'istartswith', 'endswith', 'iendswith', - 'month', 'day', 'week_day', 'hour', 'minute', 'second', - 'isnull', 'search', 'regex', 'iregex', - ): + if lookup_type in { + 'iexact', 'contains', 'icontains', + 'startswith', 'istartswith', 'endswith', 'iendswith', + 'month', 'day', 'week_day', 'hour', 'minute', 'second', + 'isnull', 'search', 'regex', 'iregex', + }: return value elif lookup_type in ('exact', 'gt', 'gte', 'lt', 'lte'): return self.get_prep_value(value) @@ -593,7 +599,7 @@ class Field(object): if isinstance(value, (list, tuple)): flat.extend(value) else: - flat.append((choice,value)) + flat.append((choice, value)) return flat flatchoices = property(_get_flatchoices) @@ -712,6 +718,7 @@ class AutoField(Field): def formfield(self, **kwargs): return None + class BooleanField(Field): empty_strings_allowed = False default_error_messages = { @@ -766,13 +773,13 @@ class BooleanField(Field): if self.choices: include_blank = (self.null or not (self.has_default() or 'initial' in kwargs)) - defaults = {'choices': self.get_choices( - include_blank=include_blank)} + defaults = {'choices': self.get_choices(include_blank=include_blank)} else: defaults = {'form_class': forms.BooleanField} defaults.update(kwargs) return super(BooleanField, self).formfield(**defaults) + class CharField(Field): description = _("String (up to %(max_length)s)") @@ -799,6 +806,7 @@ class CharField(Field): defaults.update(kwargs) return super(CharField, self).formfield(**defaults) + # TODO: Maybe move this into contrib, because it's specialized. class CommaSeparatedIntegerField(CharField): default_validators = [validators.validate_comma_separated_integer_list] @@ -813,6 +821,7 @@ class CommaSeparatedIntegerField(CharField): defaults.update(kwargs) return super(CommaSeparatedIntegerField, self).formfield(**defaults) + class DateField(Field): empty_strings_allowed = False default_error_messages = { @@ -885,7 +894,7 @@ class DateField(Field): return super(DateField, self).pre_save(model_instance, add) def contribute_to_class(self, cls, name): - super(DateField,self).contribute_to_class(cls, name) + super(DateField, self).contribute_to_class(cls, name) if not self.null: setattr(cls, 'get_next_by_%s' % self.name, curry(cls._get_next_or_previous_by_FIELD, field=self, @@ -919,6 +928,7 @@ class DateField(Field): defaults.update(kwargs) return super(DateField, self).formfield(**defaults) + class DateTimeField(DateField): empty_strings_allowed = False default_error_messages = { @@ -1025,6 +1035,7 @@ class DateTimeField(DateField): defaults.update(kwargs) return super(DateTimeField, self).formfield(**defaults) + class DecimalField(Field): empty_strings_allowed = False default_error_messages = { @@ -1096,6 +1107,7 @@ class DecimalField(Field): defaults.update(kwargs) return super(DecimalField, self).formfield(**defaults) + class EmailField(CharField): default_validators = [validators.validate_email] description = _("Email address") @@ -1122,13 +1134,14 @@ class EmailField(CharField): defaults.update(kwargs) return super(EmailField, self).formfield(**defaults) + class FilePathField(Field): description = _("File path") def __init__(self, verbose_name=None, name=None, path='', match=None, recursive=False, allow_files=True, allow_folders=False, **kwargs): self.path, self.match, self.recursive = path, match, recursive - self.allow_files, self.allow_folders = allow_files, allow_folders + self.allow_files, self.allow_folders = allow_files, allow_folders kwargs['max_length'] = kwargs.get('max_length', 100) Field.__init__(self, verbose_name, name, **kwargs) @@ -1163,6 +1176,7 @@ class FilePathField(Field): def get_internal_type(self): return "FilePathField" + class FloatField(Field): empty_strings_allowed = False default_error_messages = { @@ -1195,6 +1209,7 @@ class FloatField(Field): defaults.update(kwargs) return super(FloatField, self).formfield(**defaults) + class IntegerField(Field): empty_strings_allowed = False default_error_messages = { @@ -1233,6 +1248,7 @@ class IntegerField(Field): defaults.update(kwargs) return super(IntegerField, self).formfield(**defaults) + class BigIntegerField(IntegerField): empty_strings_allowed = False description = _("Big (8 byte) integer") @@ -1247,6 +1263,7 @@ class BigIntegerField(IntegerField): defaults.update(kwargs) return super(BigIntegerField, self).formfield(**defaults) + class IPAddressField(Field): empty_strings_allowed = False description = _("IPv4 address") @@ -1268,6 +1285,7 @@ class IPAddressField(Field): defaults.update(kwargs) return super(IPAddressField, self).formfield(**defaults) + class GenericIPAddressField(Field): empty_strings_allowed = True description = _("IP address") @@ -1383,6 +1401,7 @@ class NullBooleanField(Field): defaults.update(kwargs) return super(NullBooleanField, self).formfield(**defaults) + class PositiveIntegerField(IntegerField): description = _("Positive integer") @@ -1394,6 +1413,7 @@ class PositiveIntegerField(IntegerField): defaults.update(kwargs) return super(PositiveIntegerField, self).formfield(**defaults) + class PositiveSmallIntegerField(IntegerField): description = _("Positive small integer") @@ -1405,6 +1425,7 @@ class PositiveSmallIntegerField(IntegerField): defaults.update(kwargs) return super(PositiveSmallIntegerField, self).formfield(**defaults) + class SlugField(CharField): default_validators = [validators.validate_slug] description = _("Slug (up to %(max_length)s)") @@ -1434,12 +1455,14 @@ class SlugField(CharField): defaults.update(kwargs) return super(SlugField, self).formfield(**defaults) + class SmallIntegerField(IntegerField): description = _("Small integer") def get_internal_type(self): return "SmallIntegerField" + class TextField(Field): description = _("Text") @@ -1456,6 +1479,7 @@ class TextField(Field): defaults.update(kwargs) return super(TextField, self).formfield(**defaults) + class TimeField(Field): empty_strings_allowed = False default_error_messages = { @@ -1539,6 +1563,7 @@ class TimeField(Field): defaults.update(kwargs) return super(TimeField, self).formfield(**defaults) + class URLField(CharField): default_validators = [validators.URLValidator()] description = _("URL") @@ -1562,6 +1587,7 @@ class URLField(CharField): defaults.update(kwargs) return super(URLField, self).formfield(**defaults) + class BinaryField(Field): description = _("Raw binary data") empty_values = [None, b''] diff --git a/django/db/models/fields/files.py b/django/db/models/fields/files.py index 7663c8ab90..311f74a905 100644 --- a/django/db/models/fields/files.py +++ b/django/db/models/fields/files.py @@ -11,6 +11,7 @@ from django.utils.encoding import force_str, force_text from django.utils import six from django.utils.translation import ugettext_lazy as _ + class FieldFile(File): def __init__(self, instance, field, name): super(FieldFile, self).__init__(None, name) @@ -135,6 +136,7 @@ class FieldFile(File): # be restored later, by FileDescriptor below. return {'name': self.name, 'closed': False, '_committed': True, '_file': None} + class FileDescriptor(object): """ The descriptor for the file attribute on the model instance. Returns a @@ -205,6 +207,7 @@ class FileDescriptor(object): def __set__(self, instance, value): instance.__dict__[self.field.name] = value + class FileField(Field): # The class to wrap instance attributes in. Accessing the file object off @@ -300,6 +303,7 @@ class FileField(Field): defaults.update(kwargs) return super(FileField, self).formfield(**defaults) + class ImageFileDescriptor(FileDescriptor): """ Just like the FileDescriptor, but for ImageFields. The only difference is @@ -321,14 +325,15 @@ class ImageFileDescriptor(FileDescriptor): if previous_file is not None: self.field.update_dimension_fields(instance, force=True) -class ImageFieldFile(ImageFile, FieldFile): +class ImageFieldFile(ImageFile, FieldFile): def delete(self, save=True): # Clear the image dimensions cache if hasattr(self, '_dimensions_cache'): del self._dimensions_cache super(ImageFieldFile, self).delete(save) + class ImageField(FileField): attr_class = ImageFieldFile descriptor_class = ImageFileDescriptor diff --git a/django/db/models/fields/proxy.py b/django/db/models/fields/proxy.py index c0cc873f4c..29c782b59a 100644 --- a/django/db/models/fields/proxy.py +++ b/django/db/models/fields/proxy.py @@ -5,6 +5,7 @@ have the same attributes as fields sometimes (avoids a lot of special casing). from django.db.models import fields + class OrderWrt(fields.IntegerField): """ A proxy for the _order database field that is used when diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py index 1e7e73ddbe..1034d8b2ac 100644 --- a/django/db/models/fields/related.py +++ b/django/db/models/fields/related.py @@ -213,7 +213,7 @@ class SingleRelatedObjectDescriptor(six.with_metaclass(RenameRelatedObjectDescri # If null=True, we can assign null here, but otherwise the value needs # to be an instance of the related class. - if value is None and self.related.field.null == False: + if value is None and self.related.field.null is False: raise ValueError('Cannot assign None: "%s.%s" does not allow null values.' % (instance._meta.object_name, self.related.get_accessor_name())) elif value is not None and not isinstance(value, self.related.model): @@ -312,7 +312,7 @@ class ReverseSingleRelatedObjectDescriptor(six.with_metaclass(RenameRelatedObjec def __set__(self, instance, value): # If null=True, we can assign null here, but otherwise the value needs # to be an instance of the related class. - if value is None and self.field.null == False: + if value is None and self.field.null is False: raise ValueError('Cannot assign None: "%s.%s" does not allow null values.' % (instance._meta.object_name, self.field.name)) elif value is not None and not isinstance(value, self.field.rel.to): @@ -397,7 +397,7 @@ class ForeignRelatedObjectsDescriptor(object): def __init__(self, instance): super(RelatedManager, self).__init__() self.instance = instance - self.core_filters= {'%s__exact' % rel_field.name: instance} + self.core_filters = {'%s__exact' % rel_field.name: instance} self.model = rel_model def get_queryset(self): @@ -512,7 +512,6 @@ def create_many_related_manager(superclass, rel): "a many-to-many relationship can be used." % instance.__class__.__name__) - def _get_fk_val(self, obj, field_name): """ Returns the correct value for this relationship's foreign key. This @@ -823,6 +822,7 @@ class ReverseManyRelatedObjectsDescriptor(object): manager.clear() manager.add(*value) + class ForeignObjectRel(object): def __init__(self, field, to, related_name=None, limit_choices_to=None, parent_link=False, on_delete=None, related_query_name=None): @@ -860,6 +860,7 @@ class ForeignObjectRel(object): # example custom multicolumn joins currently have no remote field). self.field_name = None + class ManyToOneRel(ForeignObjectRel): def __init__(self, field, to, field_name, related_name=None, limit_choices_to=None, parent_link=False, on_delete=None, related_query_name=None): @@ -1125,7 +1126,7 @@ class ForeignKey(ForeignObject): def __init__(self, to, to_field=None, rel_class=ManyToOneRel, db_constraint=True, **kwargs): try: - to_name = to._meta.object_name.lower() + to._meta.object_name.lower() except AttributeError: # to._meta doesn't exist, so it must be RECURSIVE_RELATIONSHIP_CONSTANT assert isinstance(to, six.string_types), "%s(%r) is invalid. First parameter to ForeignKey must be either a model, a model name, or the string %r" % (self.__class__.__name__, to, RECURSIVE_RELATIONSHIP_CONSTANT) else: @@ -1162,7 +1163,6 @@ class ForeignKey(ForeignObject): if self.rel.on_delete is not CASCADE: kwargs['on_delete'] = self.rel.on_delete # Rel needs more work. - rel = self.rel if self.rel.field_name: kwargs['to_field'] = self.rel.field_name if isinstance(self.rel.to, six.string_types): @@ -1193,8 +1193,8 @@ class ForeignKey(ForeignObject): using = router.db_for_read(model_instance.__class__, instance=model_instance) qs = self.rel.to._default_manager.using(using).filter( - **{self.rel.field_name: value} - ) + **{self.rel.field_name: value} + ) qs = qs.complex_filter(self.rel.limit_choices_to) if not qs.exists(): raise exceptions.ValidationError( @@ -1222,7 +1222,7 @@ class ForeignKey(ForeignObject): return field_default def get_db_prep_save(self, value, connection): - if value == '' or value == None: + if value == '' or value is None: return None else: return self.related_field.get_db_prep_save(value, @@ -1389,7 +1389,6 @@ class ManyToManyField(RelatedField): if "help_text" in kwargs: del kwargs['help_text'] # Rel needs more work. - rel = self.rel if isinstance(self.rel.to, six.string_types): kwargs['to'] = self.rel.to else: diff --git a/django/db/models/fields/subclassing.py b/django/db/models/fields/subclassing.py index e6153aefe0..591adb7121 100644 --- a/django/db/models/fields/subclassing.py +++ b/django/db/models/fields/subclassing.py @@ -7,6 +7,7 @@ to_python() and the other necessary methods and everything will work seamlessly. """ + class SubfieldBase(type): """ A metaclass for custom Field subclasses. This ensures the model's attribute @@ -19,6 +20,7 @@ class SubfieldBase(type): ) return new_class + class Creator(object): """ A placeholder class that provides a way to set the attribute on the model. @@ -34,6 +36,7 @@ class Creator(object): def __set__(self, obj, value): obj.__dict__[self.field.name] = self.field.to_python(value) + def make_contrib(superclass, func=None): """ Returns a suitable contribute_to_class() method for the Field subclass. diff --git a/django/db/models/loading.py b/django/db/models/loading.py index 40aa983c25..fb61098bfa 100644 --- a/django/db/models/loading.py +++ b/django/db/models/loading.py @@ -15,9 +15,11 @@ import os __all__ = ('get_apps', 'get_app', 'get_models', 'get_model', 'register_models', 'load_app', 'app_cache_ready') + class UnavailableApp(Exception): pass + class AppCache(object): """ A cache that stores installed applications and their models. Used to diff --git a/django/db/models/manager.py b/django/db/models/manager.py index 6817c9c8ee..065718249e 100644 --- a/django/db/models/manager.py +++ b/django/db/models/manager.py @@ -6,6 +6,7 @@ from django.db.models.fields import FieldDoesNotExist from django.utils import six from django.utils.deprecation import RenameMethodsBase + def ensure_default_manager(sender, **kwargs): """ Ensures that a Model subclass contains a default manager and sets the @@ -245,7 +246,7 @@ class ManagerDescriptor(object): self.manager = manager def __get__(self, instance, type=None): - if instance != None: + if instance is not None: raise AttributeError("Manager isn't accessible via %s instances" % type.__name__) return self.manager diff --git a/django/db/models/query.py b/django/db/models/query.py index 2b8ce65e6f..8baef82bc7 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -351,7 +351,7 @@ class QuerySet(object): if objs_with_pk: self._batched_insert(objs_with_pk, fields, batch_size) if objs_without_pk: - fields= [f for f in fields if not isinstance(f, AutoField)] + fields = [f for f in fields if not isinstance(f, AutoField)] self._batched_insert(objs_without_pk, fields, batch_size) return objs @@ -818,7 +818,7 @@ class QuerySet(object): return ops = connections[self.db].ops batch_size = (batch_size or max(ops.bulk_batch_size(fields, objs), 1)) - for batch in [objs[i:i+batch_size] + for batch in [objs[i:i + batch_size] for i in range(0, len(objs), batch_size)]: self.model._base_manager._insert(batch, fields=fields, using=self.db) @@ -899,10 +899,12 @@ class QuerySet(object): # empty" result. value_annotation = True + class InstanceCheckMeta(type): def __instancecheck__(self, instance): return instance.query.is_empty() + class EmptyQuerySet(six.with_metaclass(InstanceCheckMeta)): """ Marker class usable for checking if a queryset is empty by .none(): @@ -912,6 +914,7 @@ class EmptyQuerySet(six.with_metaclass(InstanceCheckMeta)): def __init__(self, *args, **kwargs): raise TypeError("EmptyQuerySet can't be instantiated") + class ValuesQuerySet(QuerySet): def __init__(self, *args, **kwargs): super(ValuesQuerySet, self).__init__(*args, **kwargs) @@ -1240,7 +1243,7 @@ def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None, only_load.get(o.model), reverse=True): next = requested[o.field.related_query_name()] parent = klass if issubclass(o.model, klass) else None - klass_info = get_klass_info(o.model, max_depth=max_depth, cur_depth=cur_depth+1, + klass_info = get_klass_info(o.model, max_depth=max_depth, cur_depth=cur_depth + 1, requested=next, only_load=only_load, from_parent=parent) reverse_related_fields.append((o.field, klass_info)) if field_names: @@ -1251,7 +1254,7 @@ def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None, return klass, field_names, field_count, related_fields, reverse_related_fields, pk_idx -def get_cached_row(row, index_start, using, klass_info, offset=0, +def get_cached_row(row, index_start, using, klass_info, offset=0, parent_data=()): """ Helper function that recursively returns an object with the specified @@ -1276,11 +1279,10 @@ def get_cached_row(row, index_start, using, klass_info, offset=0, return None klass, field_names, field_count, related_fields, reverse_related_fields, pk_idx = klass_info - - fields = row[index_start : index_start + field_count] + fields = row[index_start:index_start + field_count] # If the pk column is None (or the Oracle equivalent ''), then the related # object must be non-existent - set the relation to None. - if fields[pk_idx] == None or fields[pk_idx] == '': + if fields[pk_idx] is None or fields[pk_idx] == '': obj = None elif field_names: fields = list(fields) @@ -1510,8 +1512,6 @@ def prefetch_related_objects(result_cache, related_lookups): if len(result_cache) == 0: return # nothing to do - model = result_cache[0].__class__ - # We need to be able to dynamically add to the list of prefetch_related # lookups that we look up (see below). So we need some book keeping to # ensure we don't do duplicate work. @@ -1538,7 +1538,7 @@ def prefetch_related_objects(result_cache, related_lookups): if len(obj_list) == 0: break - current_lookup = LOOKUP_SEP.join(attrs[0:level+1]) + current_lookup = LOOKUP_SEP.join(attrs[:level + 1]) if current_lookup in done_queries: # Skip any prefetching, and any object preparation obj_list = done_queries[current_lookup] diff --git a/django/db/models/query_utils.py b/django/db/models/query_utils.py index ee7a56a26c..2a92978beb 100644 --- a/django/db/models/query_utils.py +++ b/django/db/models/query_utils.py @@ -30,6 +30,7 @@ class QueryWrapper(object): def as_sql(self, qn=None, connection=None): return self.data + class Q(tree.Node): """ Encapsulates filters as objects that can then be combined logically (using @@ -74,6 +75,7 @@ class Q(tree.Node): clone.children.append(child) return clone + class DeferredAttribute(object): """ A wrapper for a deferred-loading field. When the value is read from this @@ -99,8 +101,7 @@ class DeferredAttribute(object): try: f = opts.get_field_by_name(self.field_name)[0] except FieldDoesNotExist: - f = [f for f in opts.fields - if f.attname == self.field_name][0] + f = [f for f in opts.fields if f.attname == self.field_name][0] name = f.name # Let's see if the field is part of the parent chain. If so we # might be able to reuse the already loaded value. Refs #18343. @@ -174,6 +175,7 @@ def select_related_descend(field, restricted, requested, load_fields, reverse=Fa return False return True + # This function is needed because data descriptors must be defined on a class # object, not an instance, to have any effect. diff --git a/django/db/models/related.py b/django/db/models/related.py index 4b00dd343b..ba2a90c545 100644 --- a/django/db/models/related.py +++ b/django/db/models/related.py @@ -10,6 +10,7 @@ PathInfo = namedtuple('PathInfo', 'from_opts to_opts target_fields join_field ' 'm2m direct') + class RelatedObject(object): def __init__(self, parent_model, model, field): self.parent_model = parent_model diff --git a/django/db/models/sql/aggregates.py b/django/db/models/sql/aggregates.py index 2bd2b2f76f..9fc5fe8a5b 100644 --- a/django/db/models/sql/aggregates.py +++ b/django/db/models/sql/aggregates.py @@ -9,6 +9,7 @@ from django.db.models.fields import IntegerField, FloatField ordinal_aggregate_field = IntegerField() computed_aggregate_field = FloatField() + class Aggregate(object): """ Default SQL Aggregate. @@ -93,6 +94,7 @@ class Avg(Aggregate): is_computed = True sql_function = 'AVG' + class Count(Aggregate): is_ordinal = True sql_function = 'COUNT' @@ -101,12 +103,15 @@ class Count(Aggregate): def __init__(self, col, distinct=False, **extra): super(Count, self).__init__(col, distinct='DISTINCT ' if distinct else '', **extra) + class Max(Aggregate): sql_function = 'MAX' + class Min(Aggregate): sql_function = 'MIN' + class StdDev(Aggregate): is_computed = True @@ -114,9 +119,11 @@ class StdDev(Aggregate): super(StdDev, self).__init__(col, **extra) self.sql_function = 'STDDEV_SAMP' if sample else 'STDDEV_POP' + class Sum(Aggregate): sql_function = 'SUM' + class Variance(Aggregate): is_computed = True diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index f70750abed..0b12bd1552 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -688,7 +688,7 @@ class SQLCompiler(object): # the relation, which is always nullable. new_nullable = True table = model._meta.db_table - self.fill_related_selections(model._meta, table, cur_depth+1, + self.fill_related_selections(model._meta, table, cur_depth + 1, next, restricted, new_nullable) def deferred_to_columns(self): @@ -915,6 +915,7 @@ class SQLDeleteCompiler(SQLCompiler): result.append('WHERE %s' % where) return ' '.join(result), tuple(params) + class SQLUpdateCompiler(SQLCompiler): def as_sql(self): """ @@ -1029,6 +1030,7 @@ class SQLUpdateCompiler(SQLCompiler): for alias in self.query.tables[1:]: self.query.alias_refcount[alias] = 0 + class SQLAggregateCompiler(SQLCompiler): def as_sql(self, qn=None): """ @@ -1050,6 +1052,7 @@ class SQLAggregateCompiler(SQLCompiler): params = params + self.query.sub_params return sql, params + class SQLDateCompiler(SQLCompiler): def results_iter(self): """ @@ -1075,6 +1078,7 @@ class SQLDateCompiler(SQLCompiler): date = date.date() yield date + class SQLDateTimeCompiler(SQLCompiler): def results_iter(self): """ @@ -1107,6 +1111,7 @@ class SQLDateTimeCompiler(SQLCompiler): datetime = timezone.make_aware(datetime, self.query.tzinfo) yield datetime + def order_modified_iter(cursor, trim, sentinel): """ Yields blocks of rows from a cursor. We use this iterator in the special diff --git a/django/db/models/sql/datastructures.py b/django/db/models/sql/datastructures.py index daaabbe6da..76b3db5b35 100644 --- a/django/db/models/sql/datastructures.py +++ b/django/db/models/sql/datastructures.py @@ -3,9 +3,11 @@ Useful auxilliary data structures for query construction. Not useful outside the SQL domain. """ + class EmptyResultSet(Exception): pass + class MultiJoin(Exception): """ Used by join construction code to indicate the point at which a @@ -17,12 +19,10 @@ class MultiJoin(Exception): # The path travelled, this includes the path to the multijoin. self.names_with_path = path_with_names + class Empty(object): pass -class RawValue(object): - def __init__(self, value): - self.value = value class Date(object): """ @@ -42,6 +42,7 @@ class Date(object): col = self.col return connection.ops.date_trunc_sql(self.lookup_type, col), [] + class DateTime(object): """ Add a datetime selection column. diff --git a/django/db/models/sql/expressions.py b/django/db/models/sql/expressions.py index 62adf79d87..f9a8929974 100644 --- a/django/db/models/sql/expressions.py +++ b/django/db/models/sql/expressions.py @@ -1,7 +1,9 @@ +import copy + from django.core.exceptions import FieldError from django.db.models.constants import LOOKUP_SEP from django.db.models.fields import FieldDoesNotExist -import copy + class SQLEvaluator(object): def __init__(self, expression, query, allow_joins=True, reuse=None): diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index 75e8e7540d..110925114c 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -615,7 +615,6 @@ class Query(object): for model, values in six.iteritems(seen): callback(target, model, values) - def deferred_to_columns_cb(self, target, model, fields): """ Callback used by deferred_to_columns(). The "target" parameter should @@ -627,7 +626,6 @@ class Query(object): for field in fields: target[table].add(field.column) - def table_alias(self, table_name, create=False): """ Returns a table alias for the given table_name and whether this is a @@ -955,7 +953,6 @@ class Query(object): self.unref_alias(alias) self.included_inherited_models = {} - def add_aggregate(self, aggregate, model, alias, is_summary): """ Adds a single aggregate expression to the Query @@ -1780,7 +1777,7 @@ class Query(object): return self._aggregate_select_cache elif self.aggregate_select_mask is not None: self._aggregate_select_cache = SortedDict([ - (k,v) for k,v in self.aggregates.items() + (k, v) for k, v in self.aggregates.items() if k in self.aggregate_select_mask ]) return self._aggregate_select_cache @@ -1793,7 +1790,7 @@ class Query(object): return self._extra_select_cache elif self.extra_select_mask is not None: self._extra_select_cache = SortedDict([ - (k,v) for k,v in self.extra.items() + (k, v) for k, v in self.extra.items() if k in self.extra_select_mask ]) return self._extra_select_cache @@ -1876,6 +1873,7 @@ class Query(object): else: return field.null + def get_order_dir(field, default='ASC'): """ Returns the field name and direction for an order specification. For @@ -1900,6 +1898,7 @@ def add_to_dict(data, key, value): else: data[key] = set([value]) + def is_reverse_o2o(field): """ A little helper to check if the given field is reverse-o2o. The field is @@ -1907,6 +1906,7 @@ def is_reverse_o2o(field): """ return not hasattr(field, 'rel') and field.field.unique + def alias_diff(refcounts_before, refcounts_after): """ Given the before and after copies of refcounts works out which aliases diff --git a/django/db/models/sql/subqueries.py b/django/db/models/sql/subqueries.py index 20770162c9..6aab02bd9a 100644 --- a/django/db/models/sql/subqueries.py +++ b/django/db/models/sql/subqueries.py @@ -7,7 +7,7 @@ from django.core.exceptions import FieldError from django.db import connections from django.db.models.constants import LOOKUP_SEP from django.db.models.fields import DateField, DateTimeField, FieldDoesNotExist -from django.db.models.sql.constants import * +from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, SelectInfo from django.db.models.sql.datastructures import Date, DateTime from django.db.models.sql.query import Query from django.db.models.sql.where import AND, Constraint @@ -20,6 +20,7 @@ from django.utils import timezone __all__ = ['DeleteQuery', 'UpdateQuery', 'InsertQuery', 'DateQuery', 'DateTimeQuery', 'AggregateQuery'] + class DeleteQuery(Query): """ Delete queries are done through this class, since they are more constrained @@ -77,7 +78,9 @@ class DeleteQuery(Query): return else: innerq.clear_select_clause() - innerq.select = [SelectInfo((self.get_initial_alias(), pk.column), None)] + innerq.select = [ + SelectInfo((self.get_initial_alias(), pk.column), None) + ] values = innerq where = self.where_class() where.add((Constraint(None, pk.column, pk), 'in', values), AND) @@ -178,6 +181,7 @@ class UpdateQuery(Query): result.append(query) return result + class InsertQuery(Query): compiler = 'SQLInsertCompiler' @@ -215,6 +219,7 @@ class InsertQuery(Query): self.objs = objs self.raw = raw + class DateQuery(Query): """ A DateQuery is a normal query, except that it specifically selects a single @@ -260,6 +265,7 @@ class DateQuery(Query): def _get_select(self, col, lookup_type): return Date(col, lookup_type) + class DateTimeQuery(DateQuery): """ A DateTimeQuery is like a DateQuery but for a datetime field. If time zone @@ -280,6 +286,7 @@ class DateTimeQuery(DateQuery): tzname = timezone._get_timezone_name(self.tzinfo) return DateTime(col, lookup_type, tzname) + class AggregateQuery(Query): """ An AggregateQuery takes another query as a parameter to the FROM diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py index 2a342d417a..024b995c99 100644 --- a/django/db/models/sql/where.py +++ b/django/db/models/sql/where.py @@ -16,10 +16,12 @@ from django.utils.six.moves import xrange from django.utils import timezone from django.utils import tree + # Connection types AND = 'AND' OR = 'OR' + class EmptyShortCircuit(Exception): """ Internal exception used to indicate that a "matches nothing" node should be @@ -27,6 +29,7 @@ class EmptyShortCircuit(Exception): """ pass + class WhereNode(tree.Node): """ Used to represent the SQL where-clause. @@ -175,7 +178,7 @@ class WhereNode(tree.Node): """ lvalue, lookup_type, value_annotation, params_or_value = child field_internal_type = lvalue.field.get_internal_type() if lvalue.field else None - + if isinstance(lvalue, Constraint): try: lvalue, params = lvalue.process(lookup_type, params_or_value, connection) @@ -304,14 +307,15 @@ class WhereNode(tree.Node): clone.children.append(child) return clone -class EmptyWhere(WhereNode): +class EmptyWhere(WhereNode): def add(self, data, connector): return def as_sql(self, qn=None, connection=None): raise EmptyResultSet + class EverythingNode(object): """ A node that matches everything. @@ -385,6 +389,7 @@ class Constraint(object): new.alias, new.col, new.field = change_map[self.alias], self.col, self.field return new + class SubqueryConstraint(object): def __init__(self, alias, columns, targets, query_object): self.alias = alias diff --git a/django/db/transaction.py b/django/db/transaction.py index 96be981e7b..bb2dbb8e9e 100644 --- a/django/db/transaction.py +++ b/django/db/transaction.py @@ -27,6 +27,7 @@ class TransactionManagementError(Exception): """ pass + ################ # Private APIs # ################ @@ -40,6 +41,7 @@ def get_connection(using=None): using = DEFAULT_DB_ALIAS return connections[using] + ########################### # Deprecated private APIs # ########################### @@ -56,6 +58,7 @@ def abort(using=None): """ get_connection(using).abort() + def enter_transaction_management(managed=True, using=None, forced=False): """ Enters transaction management for a running thread. It must be balanced with @@ -68,6 +71,7 @@ def enter_transaction_management(managed=True, using=None, forced=False): """ get_connection(using).enter_transaction_management(managed, forced) + def leave_transaction_management(using=None): """ Leaves transaction management for a running thread. A dirty flag is carried @@ -76,6 +80,7 @@ def leave_transaction_management(using=None): """ get_connection(using).leave_transaction_management() + def is_dirty(using=None): """ Returns True if the current transaction requires a commit for changes to @@ -83,6 +88,7 @@ def is_dirty(using=None): """ return get_connection(using).is_dirty() + def set_dirty(using=None): """ Sets a dirty flag for the current thread and code streak. This can be used @@ -91,6 +97,7 @@ def set_dirty(using=None): """ get_connection(using).set_dirty() + def set_clean(using=None): """ Resets a dirty flag for the current thread and code streak. This can be used @@ -99,22 +106,27 @@ def set_clean(using=None): """ get_connection(using).set_clean() + def is_managed(using=None): warnings.warn("'is_managed' is deprecated.", DeprecationWarning, stacklevel=2) + def managed(flag=True, using=None): warnings.warn("'managed' no longer serves a purpose.", DeprecationWarning, stacklevel=2) + def commit_unless_managed(using=None): warnings.warn("'commit_unless_managed' is now a no-op.", DeprecationWarning, stacklevel=2) + def rollback_unless_managed(using=None): warnings.warn("'rollback_unless_managed' is now a no-op.", DeprecationWarning, stacklevel=2) + ############### # Public APIs # ############### @@ -125,24 +137,28 @@ def get_autocommit(using=None): """ return get_connection(using).get_autocommit() + def set_autocommit(autocommit, using=None): """ Set the autocommit status of the connection. """ return get_connection(using).set_autocommit(autocommit) + def commit(using=None): """ Commits a transaction and resets the dirty flag. """ get_connection(using).commit() + def rollback(using=None): """ Rolls back a transaction and resets the dirty flag. """ get_connection(using).rollback() + def savepoint(using=None): """ Creates a savepoint (if supported and required by the backend) inside the @@ -151,6 +167,7 @@ def savepoint(using=None): """ return get_connection(using).savepoint() + def savepoint_rollback(sid, using=None): """ Rolls back the most recent savepoint (if one exists). Does nothing if @@ -158,6 +175,7 @@ def savepoint_rollback(sid, using=None): """ get_connection(using).savepoint_rollback(sid) + def savepoint_commit(sid, using=None): """ Commits the most recent savepoint (if one exists). Does nothing if @@ -165,18 +183,21 @@ def savepoint_commit(sid, using=None): """ get_connection(using).savepoint_commit(sid) + def clean_savepoints(using=None): """ Resets the counter used to generate unique savepoint ids in this thread. """ get_connection(using).clean_savepoints() + def get_rollback(using=None): """ Gets the "needs rollback" flag -- for *advanced use* only. """ return get_connection(using).get_rollback() + def set_rollback(rollback, using=None): """ Sets or unsets the "needs rollback" flag -- for *advanced use* only. @@ -191,6 +212,7 @@ def set_rollback(rollback, using=None): """ return get_connection(using).set_rollback(rollback) + ################################# # Decorators / context managers # ################################# @@ -398,6 +420,7 @@ class Transaction(object): return func(*args, **kwargs) return inner + def _transaction_func(entering, exiting, using): """ Takes 3 things, an entering function (what to do to start this block of @@ -436,6 +459,7 @@ def autocommit(using=None): return _transaction_func(entering, exiting, using) + def commit_on_success(using=None): """ This decorator activates commit on response. This way, if the view function @@ -466,6 +490,7 @@ def commit_on_success(using=None): return _transaction_func(entering, exiting, using) + def commit_manually(using=None): """ Decorator that activates manual transaction control. It just disables @@ -484,6 +509,7 @@ def commit_manually(using=None): return _transaction_func(entering, exiting, using) + def commit_on_success_unless_managed(using=None, savepoint=False): """ Transitory API to preserve backwards-compatibility while refactoring. |
