diff options
-rw-r--r-- | doc/build/changelog/changelog_10.rst | 16 | ||||
-rw-r--r-- | lib/sqlalchemy/dialects/firebird/base.py | 20 | ||||
-rw-r--r-- | lib/sqlalchemy/dialects/mssql/base.py | 82 | ||||
-rw-r--r-- | lib/sqlalchemy/dialects/mysql/base.py | 76 | ||||
-rw-r--r-- | lib/sqlalchemy/dialects/oracle/base.py | 62 | ||||
-rw-r--r-- | lib/sqlalchemy/dialects/postgresql/base.py | 68 | ||||
-rw-r--r-- | lib/sqlalchemy/dialects/sqlite/base.py | 11 | ||||
-rw-r--r-- | lib/sqlalchemy/dialects/sybase/base.py | 27 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 135 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/type_api.py | 24 | ||||
-rw-r--r-- | lib/sqlalchemy/util/__init__.py | 2 | ||||
-rw-r--r-- | lib/sqlalchemy/util/langhelpers.py | 27 | ||||
-rw-r--r-- | test/sql/test_types.py | 63 |
13 files changed, 373 insertions, 240 deletions
diff --git a/doc/build/changelog/changelog_10.rst b/doc/build/changelog/changelog_10.rst index 5d8bb7b68..089c9fafb 100644 --- a/doc/build/changelog/changelog_10.rst +++ b/doc/build/changelog/changelog_10.rst @@ -23,6 +23,22 @@ on compatibility concerns, see :doc:`/changelog/migration_10`. .. change:: + :tags: enhancement, sql + :tickets: 3074 + + Custom dialects that implement :class:`.GenericTypeCompiler` can + now be constructed such that the visit methods receive an indication + of the owning expression object, if any. Any visit method that + accepts keyword arguments (e.g. ``**kw``) will in most cases + receive a keyword argument ``type_expression``, referring to the + expression object that the type is contained within. For columns + in DDL, the dialect's compiler class may need to alter its + ``get_column_specification()`` method to support this as well. + The ``UserDefinedType.get_col_spec()`` method will also receive + ``type_expression`` if it provides ``**kw`` in its argument + signature. + + .. change:: :tags: bug, sql :tickets: 3288 diff --git a/lib/sqlalchemy/dialects/firebird/base.py b/lib/sqlalchemy/dialects/firebird/base.py index 36229a105..74e8abfc2 100644 --- a/lib/sqlalchemy/dialects/firebird/base.py +++ b/lib/sqlalchemy/dialects/firebird/base.py @@ -180,16 +180,16 @@ ischema_names = { # _FBDate, etc. as bind/result functionality is required) class FBTypeCompiler(compiler.GenericTypeCompiler): - def visit_boolean(self, type_): - return self.visit_SMALLINT(type_) + def visit_boolean(self, type_, **kw): + return self.visit_SMALLINT(type_, **kw) - def visit_datetime(self, type_): - return self.visit_TIMESTAMP(type_) + def visit_datetime(self, type_, **kw): + return self.visit_TIMESTAMP(type_, **kw) - def visit_TEXT(self, type_): + def visit_TEXT(self, type_, **kw): return "BLOB SUB_TYPE 1" - def visit_BLOB(self, type_): + def visit_BLOB(self, type_, **kw): return "BLOB SUB_TYPE 0" def _extend_string(self, type_, basic): @@ -199,16 +199,16 @@ class FBTypeCompiler(compiler.GenericTypeCompiler): else: return '%s CHARACTER SET %s' % (basic, charset) - def visit_CHAR(self, type_): - basic = super(FBTypeCompiler, self).visit_CHAR(type_) + def visit_CHAR(self, type_, **kw): + basic = super(FBTypeCompiler, self).visit_CHAR(type_, **kw) return self._extend_string(type_, basic) - def visit_VARCHAR(self, type_): + def visit_VARCHAR(self, type_, **kw): if not type_.length: raise exc.CompileError( "VARCHAR requires a length on dialect %s" % self.dialect.name) - basic = super(FBTypeCompiler, self).visit_VARCHAR(type_) + basic = super(FBTypeCompiler, self).visit_VARCHAR(type_, **kw) return self._extend_string(type_, basic) diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index 5d84975c0..92d7e4ab3 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -694,7 +694,6 @@ ischema_names = { class MSTypeCompiler(compiler.GenericTypeCompiler): - def _extend(self, spec, type_, length=None): """Extend a string-type declaration with standard SQL COLLATE annotations. @@ -715,115 +714,115 @@ class MSTypeCompiler(compiler.GenericTypeCompiler): return ' '.join([c for c in (spec, collation) if c is not None]) - def visit_FLOAT(self, type_): + def visit_FLOAT(self, type_, **kw): precision = getattr(type_, 'precision', None) if precision is None: return "FLOAT" else: return "FLOAT(%(precision)s)" % {'precision': precision} - def visit_TINYINT(self, type_): + def visit_TINYINT(self, type_, **kw): return "TINYINT" - def visit_DATETIMEOFFSET(self, type_): + def visit_DATETIMEOFFSET(self, type_, **kw): if type_.precision: return "DATETIMEOFFSET(%s)" % type_.precision else: return "DATETIMEOFFSET" - def visit_TIME(self, type_): + def visit_TIME(self, type_, **kw): precision = getattr(type_, 'precision', None) if precision: return "TIME(%s)" % precision else: return "TIME" - def visit_DATETIME2(self, type_): + def visit_DATETIME2(self, type_, **kw): precision = getattr(type_, 'precision', None) if precision: return "DATETIME2(%s)" % precision else: return "DATETIME2" - def visit_SMALLDATETIME(self, type_): + def visit_SMALLDATETIME(self, type_, **kw): return "SMALLDATETIME" - def visit_unicode(self, type_): - return self.visit_NVARCHAR(type_) + def visit_unicode(self, type_, **kw): + return self.visit_NVARCHAR(type_, **kw) - def visit_text(self, type_): + def visit_text(self, type_, **kw): if self.dialect.deprecate_large_types: - return self.visit_VARCHAR(type_) + return self.visit_VARCHAR(type_, **kw) else: - return self.visit_TEXT(type_) + return self.visit_TEXT(type_, **kw) - def visit_unicode_text(self, type_): + def visit_unicode_text(self, type_, **kw): if self.dialect.deprecate_large_types: - return self.visit_NVARCHAR(type_) + return self.visit_NVARCHAR(type_, **kw) else: - return self.visit_NTEXT(type_) + return self.visit_NTEXT(type_, **kw) - def visit_NTEXT(self, type_): + def visit_NTEXT(self, type_, **kw): return self._extend("NTEXT", type_) - def visit_TEXT(self, type_): + def visit_TEXT(self, type_, **kw): return self._extend("TEXT", type_) - def visit_VARCHAR(self, type_): + def visit_VARCHAR(self, type_, **kw): return self._extend("VARCHAR", type_, length=type_.length or 'max') - def visit_CHAR(self, type_): + def visit_CHAR(self, type_, **kw): return self._extend("CHAR", type_) - def visit_NCHAR(self, type_): + def visit_NCHAR(self, type_, **kw): return self._extend("NCHAR", type_) - def visit_NVARCHAR(self, type_): + def visit_NVARCHAR(self, type_, **kw): return self._extend("NVARCHAR", type_, length=type_.length or 'max') - def visit_date(self, type_): + def visit_date(self, type_, **kw): if self.dialect.server_version_info < MS_2008_VERSION: - return self.visit_DATETIME(type_) + return self.visit_DATETIME(type_, **kw) else: - return self.visit_DATE(type_) + return self.visit_DATE(type_, **kw) - def visit_time(self, type_): + def visit_time(self, type_, **kw): if self.dialect.server_version_info < MS_2008_VERSION: - return self.visit_DATETIME(type_) + return self.visit_DATETIME(type_, **kw) else: - return self.visit_TIME(type_) + return self.visit_TIME(type_, **kw) - def visit_large_binary(self, type_): + def visit_large_binary(self, type_, **kw): if self.dialect.deprecate_large_types: - return self.visit_VARBINARY(type_) + return self.visit_VARBINARY(type_, **kw) else: - return self.visit_IMAGE(type_) + return self.visit_IMAGE(type_, **kw) - def visit_IMAGE(self, type_): + def visit_IMAGE(self, type_, **kw): return "IMAGE" - def visit_VARBINARY(self, type_): + def visit_VARBINARY(self, type_, **kw): return self._extend( "VARBINARY", type_, length=type_.length or 'max') - def visit_boolean(self, type_): + def visit_boolean(self, type_, **kw): return self.visit_BIT(type_) - def visit_BIT(self, type_): + def visit_BIT(self, type_, **kw): return "BIT" - def visit_MONEY(self, type_): + def visit_MONEY(self, type_, **kw): return "MONEY" - def visit_SMALLMONEY(self, type_): + def visit_SMALLMONEY(self, type_, **kw): return 'SMALLMONEY' - def visit_UNIQUEIDENTIFIER(self, type_): + def visit_UNIQUEIDENTIFIER(self, type_, **kw): return "UNIQUEIDENTIFIER" - def visit_SQL_VARIANT(self, type_): + def visit_SQL_VARIANT(self, type_, **kw): return 'SQL_VARIANT' @@ -1240,8 +1239,11 @@ class MSSQLStrictCompiler(MSSQLCompiler): class MSDDLCompiler(compiler.DDLCompiler): def get_column_specification(self, column, **kwargs): - colspec = (self.preparer.format_column(column) + " " - + self.dialect.type_compiler.process(column.type)) + colspec = ( + self.preparer.format_column(column) + " " + + self.dialect.type_compiler.process( + column.type, type_expression=column) + ) if column.nullable is not None: if not column.nullable or column.primary_key or \ diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index 9c3f23cb2..ca56a4d23 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -1859,9 +1859,11 @@ class MySQLDDLCompiler(compiler.DDLCompiler): def get_column_specification(self, column, **kw): """Builds column DDL.""" - colspec = [self.preparer.format_column(column), - self.dialect.type_compiler.process(column.type) - ] + colspec = [ + self.preparer.format_column(column), + self.dialect.type_compiler.process( + column.type, type_expression=column) + ] default = self.get_column_default_string(column) if default is not None: @@ -2059,7 +2061,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): def _mysql_type(self, type_): return isinstance(type_, (_StringType, _NumericType)) - def visit_NUMERIC(self, type_): + def visit_NUMERIC(self, type_, **kw): if type_.precision is None: return self._extend_numeric(type_, "NUMERIC") elif type_.scale is None: @@ -2072,7 +2074,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): {'precision': type_.precision, 'scale': type_.scale}) - def visit_DECIMAL(self, type_): + def visit_DECIMAL(self, type_, **kw): if type_.precision is None: return self._extend_numeric(type_, "DECIMAL") elif type_.scale is None: @@ -2085,7 +2087,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): {'precision': type_.precision, 'scale': type_.scale}) - def visit_DOUBLE(self, type_): + def visit_DOUBLE(self, type_, **kw): if type_.precision is not None and type_.scale is not None: return self._extend_numeric(type_, "DOUBLE(%(precision)s, %(scale)s)" % @@ -2094,7 +2096,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): else: return self._extend_numeric(type_, 'DOUBLE') - def visit_REAL(self, type_): + def visit_REAL(self, type_, **kw): if type_.precision is not None and type_.scale is not None: return self._extend_numeric(type_, "REAL(%(precision)s, %(scale)s)" % @@ -2103,7 +2105,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): else: return self._extend_numeric(type_, 'REAL') - def visit_FLOAT(self, type_): + def visit_FLOAT(self, type_, **kw): if self._mysql_type(type_) and \ type_.scale is not None and \ type_.precision is not None: @@ -2115,7 +2117,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): else: return self._extend_numeric(type_, "FLOAT") - def visit_INTEGER(self, type_): + def visit_INTEGER(self, type_, **kw): if self._mysql_type(type_) and type_.display_width is not None: return self._extend_numeric( type_, "INTEGER(%(display_width)s)" % @@ -2123,7 +2125,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): else: return self._extend_numeric(type_, "INTEGER") - def visit_BIGINT(self, type_): + def visit_BIGINT(self, type_, **kw): if self._mysql_type(type_) and type_.display_width is not None: return self._extend_numeric( type_, "BIGINT(%(display_width)s)" % @@ -2131,7 +2133,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): else: return self._extend_numeric(type_, "BIGINT") - def visit_MEDIUMINT(self, type_): + def visit_MEDIUMINT(self, type_, **kw): if self._mysql_type(type_) and type_.display_width is not None: return self._extend_numeric( type_, "MEDIUMINT(%(display_width)s)" % @@ -2139,14 +2141,14 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): else: return self._extend_numeric(type_, "MEDIUMINT") - def visit_TINYINT(self, type_): + def visit_TINYINT(self, type_, **kw): if self._mysql_type(type_) and type_.display_width is not None: return self._extend_numeric(type_, "TINYINT(%s)" % type_.display_width) else: return self._extend_numeric(type_, "TINYINT") - def visit_SMALLINT(self, type_): + def visit_SMALLINT(self, type_, **kw): if self._mysql_type(type_) and type_.display_width is not None: return self._extend_numeric(type_, "SMALLINT(%(display_width)s)" % @@ -2155,55 +2157,55 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): else: return self._extend_numeric(type_, "SMALLINT") - def visit_BIT(self, type_): + def visit_BIT(self, type_, **kw): if type_.length is not None: return "BIT(%s)" % type_.length else: return "BIT" - def visit_DATETIME(self, type_): + def visit_DATETIME(self, type_, **kw): if getattr(type_, 'fsp', None): return "DATETIME(%d)" % type_.fsp else: return "DATETIME" - def visit_DATE(self, type_): + def visit_DATE(self, type_, **kw): return "DATE" - def visit_TIME(self, type_): + def visit_TIME(self, type_, **kw): if getattr(type_, 'fsp', None): return "TIME(%d)" % type_.fsp else: return "TIME" - def visit_TIMESTAMP(self, type_): + def visit_TIMESTAMP(self, type_, **kw): if getattr(type_, 'fsp', None): return "TIMESTAMP(%d)" % type_.fsp else: return "TIMESTAMP" - def visit_YEAR(self, type_): + def visit_YEAR(self, type_, **kw): if type_.display_width is None: return "YEAR" else: return "YEAR(%s)" % type_.display_width - def visit_TEXT(self, type_): + def visit_TEXT(self, type_, **kw): if type_.length: return self._extend_string(type_, {}, "TEXT(%d)" % type_.length) else: return self._extend_string(type_, {}, "TEXT") - def visit_TINYTEXT(self, type_): + def visit_TINYTEXT(self, type_, **kw): return self._extend_string(type_, {}, "TINYTEXT") - def visit_MEDIUMTEXT(self, type_): + def visit_MEDIUMTEXT(self, type_, **kw): return self._extend_string(type_, {}, "MEDIUMTEXT") - def visit_LONGTEXT(self, type_): + def visit_LONGTEXT(self, type_, **kw): return self._extend_string(type_, {}, "LONGTEXT") - def visit_VARCHAR(self, type_): + def visit_VARCHAR(self, type_, **kw): if type_.length: return self._extend_string( type_, {}, "VARCHAR(%d)" % type_.length) @@ -2212,14 +2214,14 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): "VARCHAR requires a length on dialect %s" % self.dialect.name) - def visit_CHAR(self, type_): + def visit_CHAR(self, type_, **kw): if type_.length: return self._extend_string(type_, {}, "CHAR(%(length)s)" % {'length': type_.length}) else: return self._extend_string(type_, {}, "CHAR") - def visit_NVARCHAR(self, type_): + def visit_NVARCHAR(self, type_, **kw): # We'll actually generate the equiv. "NATIONAL VARCHAR" instead # of "NVARCHAR". if type_.length: @@ -2231,7 +2233,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): "NVARCHAR requires a length on dialect %s" % self.dialect.name) - def visit_NCHAR(self, type_): + def visit_NCHAR(self, type_, **kw): # We'll actually generate the equiv. # "NATIONAL CHAR" instead of "NCHAR". if type_.length: @@ -2241,31 +2243,31 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): else: return self._extend_string(type_, {'national': True}, "CHAR") - def visit_VARBINARY(self, type_): + def visit_VARBINARY(self, type_, **kw): return "VARBINARY(%d)" % type_.length - def visit_large_binary(self, type_): + def visit_large_binary(self, type_, **kw): return self.visit_BLOB(type_) - def visit_enum(self, type_): + def visit_enum(self, type_, **kw): if not type_.native_enum: return super(MySQLTypeCompiler, self).visit_enum(type_) else: return self._visit_enumerated_values("ENUM", type_, type_.enums) - def visit_BLOB(self, type_): + def visit_BLOB(self, type_, **kw): if type_.length: return "BLOB(%d)" % type_.length else: return "BLOB" - def visit_TINYBLOB(self, type_): + def visit_TINYBLOB(self, type_, **kw): return "TINYBLOB" - def visit_MEDIUMBLOB(self, type_): + def visit_MEDIUMBLOB(self, type_, **kw): return "MEDIUMBLOB" - def visit_LONGBLOB(self, type_): + def visit_LONGBLOB(self, type_, **kw): return "LONGBLOB" def _visit_enumerated_values(self, name, type_, enumerated_values): @@ -2276,15 +2278,15 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): name, ",".join(quoted_enums)) ) - def visit_ENUM(self, type_): + def visit_ENUM(self, type_, **kw): return self._visit_enumerated_values("ENUM", type_, type_._enumerated_values) - def visit_SET(self, type_): + def visit_SET(self, type_, **kw): return self._visit_enumerated_values("SET", type_, type_._enumerated_values) - def visit_BOOLEAN(self, type): + def visit_BOOLEAN(self, type, **kw): return "BOOL" diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index 9f375da94..b482c9069 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -457,19 +457,19 @@ class OracleTypeCompiler(compiler.GenericTypeCompiler): # Oracle does not allow milliseconds in DATE # Oracle does not support TIME columns - def visit_datetime(self, type_): - return self.visit_DATE(type_) + def visit_datetime(self, type_, **kw): + return self.visit_DATE(type_, **kw) - def visit_float(self, type_): - return self.visit_FLOAT(type_) + def visit_float(self, type_, **kw): + return self.visit_FLOAT(type_, **kw) - def visit_unicode(self, type_): + def visit_unicode(self, type_, **kw): if self.dialect._supports_nchar: - return self.visit_NVARCHAR2(type_) + return self.visit_NVARCHAR2(type_, **kw) else: - return self.visit_VARCHAR2(type_) + return self.visit_VARCHAR2(type_, **kw) - def visit_INTERVAL(self, type_): + def visit_INTERVAL(self, type_, **kw): return "INTERVAL DAY%s TO SECOND%s" % ( type_.day_precision is not None and "(%d)" % type_.day_precision or @@ -479,22 +479,22 @@ class OracleTypeCompiler(compiler.GenericTypeCompiler): "", ) - def visit_LONG(self, type_): + def visit_LONG(self, type_, **kw): return "LONG" - def visit_TIMESTAMP(self, type_): + def visit_TIMESTAMP(self, type_, **kw): if type_.timezone: return "TIMESTAMP WITH TIME ZONE" else: return "TIMESTAMP" - def visit_DOUBLE_PRECISION(self, type_): - return self._generate_numeric(type_, "DOUBLE PRECISION") + def visit_DOUBLE_PRECISION(self, type_, **kw): + return self._generate_numeric(type_, "DOUBLE PRECISION", **kw) def visit_NUMBER(self, type_, **kw): return self._generate_numeric(type_, "NUMBER", **kw) - def _generate_numeric(self, type_, name, precision=None, scale=None): + def _generate_numeric(self, type_, name, precision=None, scale=None, **kw): if precision is None: precision = type_.precision @@ -510,17 +510,17 @@ class OracleTypeCompiler(compiler.GenericTypeCompiler): n = "%(name)s(%(precision)s, %(scale)s)" return n % {'name': name, 'precision': precision, 'scale': scale} - def visit_string(self, type_): - return self.visit_VARCHAR2(type_) + def visit_string(self, type_, **kw): + return self.visit_VARCHAR2(type_, **kw) - def visit_VARCHAR2(self, type_): + def visit_VARCHAR2(self, type_, **kw): return self._visit_varchar(type_, '', '2') - def visit_NVARCHAR2(self, type_): + def visit_NVARCHAR2(self, type_, **kw): return self._visit_varchar(type_, 'N', '2') visit_NVARCHAR = visit_NVARCHAR2 - def visit_VARCHAR(self, type_): + def visit_VARCHAR(self, type_, **kw): return self._visit_varchar(type_, '', '') def _visit_varchar(self, type_, n, num): @@ -533,31 +533,31 @@ class OracleTypeCompiler(compiler.GenericTypeCompiler): varchar = "%(n)sVARCHAR%(two)s(%(length)s)" return varchar % {'length': type_.length, 'two': num, 'n': n} - def visit_text(self, type_): - return self.visit_CLOB(type_) + def visit_text(self, type_, **kw): + return self.visit_CLOB(type_, **kw) - def visit_unicode_text(self, type_): + def visit_unicode_text(self, type_, **kw): if self.dialect._supports_nchar: - return self.visit_NCLOB(type_) + return self.visit_NCLOB(type_, **kw) else: - return self.visit_CLOB(type_) + return self.visit_CLOB(type_, **kw) - def visit_large_binary(self, type_): - return self.visit_BLOB(type_) + def visit_large_binary(self, type_, **kw): + return self.visit_BLOB(type_, **kw) - def visit_big_integer(self, type_): - return self.visit_NUMBER(type_, precision=19) + def visit_big_integer(self, type_, **kw): + return self.visit_NUMBER(type_, precision=19, **kw) - def visit_boolean(self, type_): - return self.visit_SMALLINT(type_) + def visit_boolean(self, type_, **kw): + return self.visit_SMALLINT(type_, **kw) - def visit_RAW(self, type_): + def visit_RAW(self, type_, **kw): if type_.length: return "RAW(%(length)s)" % {'length': type_.length} else: return "RAW" - def visit_ROWID(self, type_): + def visit_ROWID(self, type_, **kw): return "ROWID" diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 0817fe837..89bea100e 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -1425,7 +1425,8 @@ class PGDDLCompiler(compiler.DDLCompiler): else: colspec += " SERIAL" else: - colspec += " " + self.dialect.type_compiler.process(column.type) + colspec += " " + self.dialect.type_compiler.process(column.type, + type_expression=column) default = self.get_column_default_string(column) if default is not None: colspec += " DEFAULT " + default @@ -1545,94 +1546,93 @@ class PGDDLCompiler(compiler.DDLCompiler): class PGTypeCompiler(compiler.GenericTypeCompiler): - - def visit_TSVECTOR(self, type): + def visit_TSVECTOR(self, type, **kw): return "TSVECTOR" - def visit_INET(self, type_): + def visit_INET(self, type_, **kw): return "INET" - def visit_CIDR(self, type_): + def visit_CIDR(self, type_, **kw): return "CIDR" - def visit_MACADDR(self, type_): + def visit_MACADDR(self, type_, **kw): return "MACADDR" - def visit_OID(self, type_): + def visit_OID(self, type_, **kw): return "OID" - def visit_FLOAT(self, type_): + def visit_FLOAT(self, type_, **kw): if not type_.precision: return "FLOAT" else: return "FLOAT(%(precision)s)" % {'precision': type_.precision} - def visit_DOUBLE_PRECISION(self, type_): + def visit_DOUBLE_PRECISION(self, type_, **kw): return "DOUBLE PRECISION" - def visit_BIGINT(self, type_): + def visit_BIGINT(self, type_, **kw): return "BIGINT" - def visit_HSTORE(self, type_): + def visit_HSTORE(self, type_, **kw): return "HSTORE" - def visit_JSON(self, type_): + def visit_JSON(self, type_, **kw): return "JSON" - def visit_JSONB(self, type_): + def visit_JSONB(self, type_, **kw): return "JSONB" - def visit_INT4RANGE(self, type_): + def visit_INT4RANGE(self, type_, **kw): return "INT4RANGE" - def visit_INT8RANGE(self, type_): + def visit_INT8RANGE(self, type_, **kw): return "INT8RANGE" - def visit_NUMRANGE(self, type_): + def visit_NUMRANGE(self, type_, **kw): return "NUMRANGE" - def visit_DATERANGE(self, type_): + def visit_DATERANGE(self, type_, **kw): return "DATERANGE" - def visit_TSRANGE(self, type_): + def visit_TSRANGE(self, type_, **kw): return "TSRANGE" - def visit_TSTZRANGE(self, type_): + def visit_TSTZRANGE(self, type_, **kw): return "TSTZRANGE" - def visit_datetime(self, type_): - return self.visit_TIMESTAMP(type_) + def visit_datetime(self, type_, **kw): + return self.visit_TIMESTAMP(type_, **kw) - def visit_enum(self, type_): + def visit_enum(self, type_, **kw): if not type_.native_enum or not self.dialect.supports_native_enum: - return super(PGTypeCompiler, self).visit_enum(type_) + return super(PGTypeCompiler, self).visit_enum(type_, **kw) else: - return self.visit_ENUM(type_) + return self.visit_ENUM(type_, **kw) - def visit_ENUM(self, type_): + def visit_ENUM(self, type_, **kw): return self.dialect.identifier_preparer.format_type(type_) - def visit_TIMESTAMP(self, type_): + def visit_TIMESTAMP(self, type_, **kw): return "TIMESTAMP%s %s" % ( getattr(type_, 'precision', None) and "(%d)" % type_.precision or "", (type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE" ) - def visit_TIME(self, type_): + def visit_TIME(self, type_, **kw): return "TIME%s %s" % ( getattr(type_, 'precision', None) and "(%d)" % type_.precision or "", (type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE" ) - def visit_INTERVAL(self, type_): + def visit_INTERVAL(self, type_, **kw): if type_.precision is not None: return "INTERVAL(%d)" % type_.precision else: return "INTERVAL" - def visit_BIT(self, type_): + def visit_BIT(self, type_, **kw): if type_.varying: compiled = "BIT VARYING" if type_.length is not None: @@ -1641,16 +1641,16 @@ class PGTypeCompiler(compiler.GenericTypeCompiler): compiled = "BIT(%d)" % type_.length return compiled - def visit_UUID(self, type_): + def visit_UUID(self, type_, **kw): return "UUID" - def visit_large_binary(self, type_): - return self.visit_BYTEA(type_) + def visit_large_binary(self, type_, **kw): + return self.visit_BYTEA(type_, **kw) - def visit_BYTEA(self, type_): + def visit_BYTEA(self, type_, **kw): return "BYTEA" - def visit_ARRAY(self, type_): + def visit_ARRAY(self, type_, **kw): return self.process(type_.item_type) + ('[]' * (type_.dimensions if type_.dimensions is not None else 1)) diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index 3d7b0788b..f74421967 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -660,7 +660,8 @@ class SQLiteCompiler(compiler.SQLCompiler): class SQLiteDDLCompiler(compiler.DDLCompiler): def get_column_specification(self, column, **kwargs): - coltype = self.dialect.type_compiler.process(column.type) + coltype = self.dialect.type_compiler.process( + column.type, type_expression=column) colspec = self.preparer.format_column(column) + " " + coltype default = self.get_column_default_string(column) if default is not None: @@ -716,24 +717,24 @@ class SQLiteDDLCompiler(compiler.DDLCompiler): class SQLiteTypeCompiler(compiler.GenericTypeCompiler): - def visit_large_binary(self, type_): + def visit_large_binary(self, type_, **kw): return self.visit_BLOB(type_) - def visit_DATETIME(self, type_): + def visit_DATETIME(self, type_, **kw): if not isinstance(type_, _DateTimeMixin) or \ type_.format_is_text_affinity: return super(SQLiteTypeCompiler, self).visit_DATETIME(type_) else: return "DATETIME_CHAR" - def visit_DATE(self, type_): + def visit_DATE(self, type_, **kw): if not isinstance(type_, _DateTimeMixin) or \ type_.format_is_text_affinity: return super(SQLiteTypeCompiler, self).visit_DATE(type_) else: return "DATE_CHAR" - def visit_TIME(self, type_): + def visit_TIME(self, type_, **kw): if not isinstance(type_, _DateTimeMixin) or \ type_.format_is_text_affinity: return super(SQLiteTypeCompiler, self).visit_TIME(type_) diff --git a/lib/sqlalchemy/dialects/sybase/base.py b/lib/sqlalchemy/dialects/sybase/base.py index f65a76a27..369420358 100644 --- a/lib/sqlalchemy/dialects/sybase/base.py +++ b/lib/sqlalchemy/dialects/sybase/base.py @@ -146,40 +146,40 @@ class IMAGE(sqltypes.LargeBinary): class SybaseTypeCompiler(compiler.GenericTypeCompiler): - def visit_large_binary(self, type_): + def visit_large_binary(self, type_, **kw): return self.visit_IMAGE(type_) - def visit_boolean(self, type_): + def visit_boolean(self, type_, **kw): return self.visit_BIT(type_) - def visit_unicode(self, type_): + def visit_unicode(self, type_, **kw): return self.visit_NVARCHAR(type_) - def visit_UNICHAR(self, type_): + def visit_UNICHAR(self, type_, **kw): return "UNICHAR(%d)" % type_.length - def visit_UNIVARCHAR(self, type_): + def visit_UNIVARCHAR(self, type_, **kw): return "UNIVARCHAR(%d)" % type_.length - def visit_UNITEXT(self, type_): + def visit_UNITEXT(self, type_, **kw): return "UNITEXT" - def visit_TINYINT(self, type_): + def visit_TINYINT(self, type_, **kw): return "TINYINT" - def visit_IMAGE(self, type_): + def visit_IMAGE(self, type_, **kw): return "IMAGE" - def visit_BIT(self, type_): + def visit_BIT(self, type_, **kw): return "BIT" - def visit_MONEY(self, type_): + def visit_MONEY(self, type_, **kw): return "MONEY" - def visit_SMALLMONEY(self, type_): + def visit_SMALLMONEY(self, type_, **kw): return "SMALLMONEY" - def visit_UNIQUEIDENTIFIER(self, type_): + def visit_UNIQUEIDENTIFIER(self, type_, **kw): return "UNIQUEIDENTIFIER" ischema_names = { @@ -377,7 +377,8 @@ class SybaseSQLCompiler(compiler.SQLCompiler): class SybaseDDLCompiler(compiler.DDLCompiler): def get_column_specification(self, column, **kwargs): colspec = self.preparer.format_column(column) + " " + \ - self.dialect.type_compiler.process(column.type) + self.dialect.type_compiler.process( + column.type, type_expression=column) if column.table is None: raise exc.CompileError( diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index ca14c9371..da62b1434 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -248,15 +248,16 @@ class Compiled(object): return self.execute(*multiparams, **params).scalar() -class TypeCompiler(object): - +class TypeCompiler(util.with_metaclass(util.EnsureKWArgType, object)): """Produces DDL specification for TypeEngine objects.""" + ensure_kwarg = 'visit_\w+' + def __init__(self, dialect): self.dialect = dialect - def process(self, type_): - return type_._compiler_dispatch(self) + def process(self, type_, **kw): + return type_._compiler_dispatch(self, **kw) class _CompileLabel(visitors.Visitable): @@ -638,8 +639,9 @@ class SQLCompiler(Compiled): def visit_index(self, index, **kwargs): return index.name - def visit_typeclause(self, typeclause, **kwargs): - return self.dialect.type_compiler.process(typeclause.type) + def visit_typeclause(self, typeclause, **kw): + kw['type_expression'] = typeclause + return self.dialect.type_compiler.process(typeclause.type, **kw) def post_process_text(self, text): return text @@ -2259,7 +2261,8 @@ class DDLCompiler(Compiled): def get_column_specification(self, column, **kwargs): colspec = self.preparer.format_column(column) + " " + \ - self.dialect.type_compiler.process(column.type) + self.dialect.type_compiler.process( + column.type, type_expression=column) default = self.get_column_default_string(column) if default is not None: colspec += " DEFAULT " + default @@ -2383,13 +2386,13 @@ class DDLCompiler(Compiled): class GenericTypeCompiler(TypeCompiler): - def visit_FLOAT(self, type_): + def visit_FLOAT(self, type_, **kw): return "FLOAT" - def visit_REAL(self, type_): + def visit_REAL(self, type_, **kw): return "REAL" - def visit_NUMERIC(self, type_): + def visit_NUMERIC(self, type_, **kw): if type_.precision is None: return "NUMERIC" elif type_.scale is None: @@ -2400,7 +2403,7 @@ class GenericTypeCompiler(TypeCompiler): {'precision': type_.precision, 'scale': type_.scale} - def visit_DECIMAL(self, type_): + def visit_DECIMAL(self, type_, **kw): if type_.precision is None: return "DECIMAL" elif type_.scale is None: @@ -2411,31 +2414,31 @@ class GenericTypeCompiler(TypeCompiler): {'precision': type_.precision, 'scale': type_.scale} - def visit_INTEGER(self, type_): + def visit_INTEGER(self, type_, **kw): return "INTEGER" - def visit_SMALLINT(self, type_): + def visit_SMALLINT(self, type_, **kw): return "SMALLINT" - def visit_BIGINT(self, type_): + def visit_BIGINT(self, type_, **kw): return "BIGINT" - def visit_TIMESTAMP(self, type_): + def visit_TIMESTAMP(self, type_, **kw): return 'TIMESTAMP' - def visit_DATETIME(self, type_): + def visit_DATETIME(self, type_, **kw): return "DATETIME" - def visit_DATE(self, type_): + def visit_DATE(self, type_, **kw): return "DATE" - def visit_TIME(self, type_): + def visit_TIME(self, type_, **kw): return "TIME" - def visit_CLOB(self, type_): + def visit_CLOB(self, type_, **kw): return "CLOB" - def visit_NCLOB(self, type_): + def visit_NCLOB(self, type_, **kw): return "NCLOB" def _render_string_type(self, type_, name): @@ -2447,91 +2450,91 @@ class GenericTypeCompiler(TypeCompiler): text += ' COLLATE "%s"' % type_.collation return text - def visit_CHAR(self, type_): + def visit_CHAR(self, type_, **kw): return self._render_string_type(type_, "CHAR") - def visit_NCHAR(self, type_): + def visit_NCHAR(self, type_, **kw): return self._render_string_type(type_, "NCHAR") - def visit_VARCHAR(self, type_): + def visit_VARCHAR(self, type_, **kw): return self._render_string_type(type_, "VARCHAR") - def visit_NVARCHAR(self, type_): + def visit_NVARCHAR(self, type_, **kw): return self._render_string_type(type_, "NVARCHAR") - def visit_TEXT(self, type_): + def visit_TEXT(self, type_, **kw): return self._render_string_type(type_, "TEXT") - def visit_BLOB(self, type_): + def visit_BLOB(self, type_, **kw): return "BLOB" - def visit_BINARY(self, type_): + def visit_BINARY(self, type_, **kw): return "BINARY" + (type_.length and "(%d)" % type_.length or "") - def visit_VARBINARY(self, type_): + def visit_VARBINARY(self, type_, **kw): return "VARBINARY" + (type_.length and "(%d)" % type_.length or "") - def visit_BOOLEAN(self, type_): + def visit_BOOLEAN(self, type_, **kw): return "BOOLEAN" - def visit_large_binary(self, type_): - return self.visit_BLOB(type_) + def visit_large_binary(self, type_, **kw): + return self.visit_BLOB(type_, **kw) - def visit_boolean(self, type_): - return self.visit_BOOLEAN(type_) + def visit_boolean(self, type_, **kw): + return self.visit_BOOLEAN(type_, **kw) - def visit_time(self, type_): - return self.visit_TIME(type_) + def visit_time(self, type_, **kw): + return self.visit_TIME(type_, **kw) - def visit_datetime(self, type_): - return self.visit_DATETIME(type_) + def visit_datetime(self, type_, **kw): + return self.visit_DATETIME(type_, **kw) - def visit_date(self, type_): - return self.visit_DATE(type_) + def visit_date(self, type_, **kw): + return self.visit_DATE(type_, **kw) - def visit_big_integer(self, type_): - return self.visit_BIGINT(type_) + def visit_big_integer(self, type_, **kw): + return self.visit_BIGINT(type_, **kw) - def visit_small_integer(self, type_): - return self.visit_SMALLINT(type_) + def visit_small_integer(self, type_, **kw): + return self.visit_SMALLINT(type_, **kw) - def visit_integer(self, type_): - return self.visit_INTEGER(type_) + def visit_integer(self, type_, **kw): + return self.visit_INTEGER(type_, **kw) - def visit_real(self, type_): - return self.visit_REAL(type_) + def visit_real(self, type_, **kw): + return self.visit_REAL(type_, **kw) - def visit_float(self, type_): - return self.visit_FLOAT(type_) + def visit_float(self, type_, **kw): + return self.visit_FLOAT(type_, **kw) - def visit_numeric(self, type_): - return self.visit_NUMERIC(type_) + def visit_numeric(self, type_, **kw): + return self.visit_NUMERIC(type_, **kw) - def visit_string(self, type_): - return self.visit_VARCHAR(type_) + def visit_string(self, type_, **kw): + return self.visit_VARCHAR(type_, **kw) - def visit_unicode(self, type_): - return self.visit_VARCHAR(type_) + def visit_unicode(self, type_, **kw): + return self.visit_VARCHAR(type_, **kw) - def visit_text(self, type_): - return self.visit_TEXT(type_) + def visit_text(self, type_, **kw): + return self.visit_TEXT(type_, **kw) - def visit_unicode_text(self, type_): - return self.visit_TEXT(type_) + def visit_unicode_text(self, type_, **kw): + return self.visit_TEXT(type_, **kw) - def visit_enum(self, type_): - return self.visit_VARCHAR(type_) + def visit_enum(self, type_, **kw): + return self.visit_VARCHAR(type_, **kw) - def visit_null(self, type_): + def visit_null(self, type_, **kw): raise exc.CompileError("Can't generate DDL for %r; " "did you forget to specify a " "type on this Column?" % type_) - def visit_type_decorator(self, type_): - return self.process(type_.type_engine(self.dialect)) + def visit_type_decorator(self, type_, **kw): + return self.process(type_.type_engine(self.dialect), **kw) - def visit_user_defined(self, type_): - return type_.get_col_spec() + def visit_user_defined(self, type_, **kw): + return type_.get_col_spec(**kw) class IdentifierPreparer(object): diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py index bff497800..19398ae96 100644 --- a/lib/sqlalchemy/sql/type_api.py +++ b/lib/sqlalchemy/sql/type_api.py @@ -12,7 +12,7 @@ from .. import exc, util from . import operators -from .visitors import Visitable +from .visitors import Visitable, VisitableType # these are back-assigned by sqltypes. BOOLEANTYPE = None @@ -460,7 +460,11 @@ class TypeEngine(Visitable): return util.generic_repr(self) -class UserDefinedType(TypeEngine): +class VisitableCheckKWArg(util.EnsureKWArgType, VisitableType): + pass + + +class UserDefinedType(util.with_metaclass(VisitableCheckKWArg, TypeEngine)): """Base for user defined types. This should be the base of new types. Note that @@ -473,7 +477,7 @@ class UserDefinedType(TypeEngine): def __init__(self, precision = 8): self.precision = precision - def get_col_spec(self): + def get_col_spec(self, **kw): return "MYTYPE(%s)" % self.precision def bind_processor(self, dialect): @@ -493,9 +497,23 @@ class UserDefinedType(TypeEngine): Column('data', MyType(16)) ) + The ``get_col_spec()`` method will in most cases receive a keyword + argument ``type_expression`` which refers to the owning expression + of the type as being compiled, such as a :class:`.Column` or + :func:`.cast` construct. This keyword is only sent if the method + accepts keyword arguments (e.g. ``**kw``) in its argument signature; + introspection is used to check for this in order to support legacy + forms of this function. + + .. versionadded:: 1.0.0 the owning expression is passed to + the ``get_col_spec()`` method via the keyword argument + ``type_expression``, if it receives ``**kw`` in its signature. + """ __visit_name__ = "user_defined" + ensure_kwarg = 'get_col_spec' + class Comparator(TypeEngine.Comparator): __slots__ = () diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py index c23b0196f..ceee18d86 100644 --- a/lib/sqlalchemy/util/__init__.py +++ b/lib/sqlalchemy/util/__init__.py @@ -36,7 +36,7 @@ from .langhelpers import iterate_attributes, class_hierarchy, \ generic_repr, counter, PluginLoader, hybridproperty, hybridmethod, \ safe_reraise,\ get_callable_argspec, only_once, attrsetter, ellipses_string, \ - warn_limited, map_bits, MemoizedSlots + warn_limited, map_bits, MemoizedSlots, EnsureKWArgType from .deprecations import warn_deprecated, warn_pending_deprecation, \ deprecated, pending_deprecation, inject_docstring_text diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py index 22b6ad4ca..5a938501a 100644 --- a/lib/sqlalchemy/util/langhelpers.py +++ b/lib/sqlalchemy/util/langhelpers.py @@ -1348,6 +1348,7 @@ def chop_traceback(tb, exclude_prefix=_UNITTEST_RE, exclude_suffix=_SQLA_RE): NoneType = type(None) + def attrsetter(attrname): code = \ "def set(obj, value):"\ @@ -1355,3 +1356,29 @@ def attrsetter(attrname): env = locals().copy() exec(code, env) return env['set'] + + +class EnsureKWArgType(type): + """Apply translation of functions to accept **kw arguments if they + don't already. + + """ + def __init__(cls, clsname, bases, clsdict): + fn_reg = cls.ensure_kwarg + if fn_reg: + for key in clsdict: + m = re.match(fn_reg, key) + if m: + fn = clsdict[key] + spec = inspect.getargspec(fn) + if not spec.keywords: + clsdict[key] = wrapped = cls._wrap_w_kw(fn) + setattr(cls, key, wrapped) + super(EnsureKWArgType, cls).__init__(clsname, bases, clsdict) + + def _wrap_w_kw(self, fn): + + def wrap(*arg, **kw): + return fn(*arg) + return update_wrapper(wrap, fn) + diff --git a/test/sql/test_types.py b/test/sql/test_types.py index 6ffd88d78..38b3ced13 100644 --- a/test/sql/test_types.py +++ b/test/sql/test_types.py @@ -10,6 +10,8 @@ from sqlalchemy import ( type_coerce, VARCHAR, Time, DateTime, BigInteger, SmallInteger, BOOLEAN, BLOB, NCHAR, NVARCHAR, CLOB, TIME, DATE, DATETIME, TIMESTAMP, SMALLINT, INTEGER, DECIMAL, NUMERIC, FLOAT, REAL) +from sqlalchemy.sql import ddl + from sqlalchemy import exc, types, util, dialects for name in dialects.__all__: __import__("sqlalchemy.dialects.%s" % name) @@ -309,6 +311,24 @@ class UserDefinedTest(fixtures.TablesTest, AssertsCompiledSQL): literal_binds=True ) + def test_kw_colspec(self): + class MyType(types.UserDefinedType): + def get_col_spec(self, **kw): + return "FOOB %s" % kw['type_expression'].name + + class MyOtherType(types.UserDefinedType): + def get_col_spec(self): + return "BAR" + + self.assert_compile( + ddl.CreateColumn(Column('bar', MyType)), + "bar FOOB bar" + ) + self.assert_compile( + ddl.CreateColumn(Column('bar', MyOtherType)), + "bar BAR" + ) + def test_typedecorator_literal_render_fallback_bound(self): # fall back to process_bind_param for literal # value rendering. @@ -1642,6 +1662,49 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): def test_decimal_scale(self): self.assert_compile(types.DECIMAL(2, 4), 'DECIMAL(2, 4)') + def test_kwarg_legacy_typecompiler(self): + from sqlalchemy.sql import compiler + + class SomeTypeCompiler(compiler.GenericTypeCompiler): + # transparently decorated w/ kw decorator + def visit_VARCHAR(self, type_): + return "MYVARCHAR" + + # not affected + def visit_INTEGER(self, type_, **kw): + return "MYINTEGER %s" % kw['type_expression'].name + + dialect = default.DefaultDialect() + dialect.type_compiler = SomeTypeCompiler(dialect) + self.assert_compile( + ddl.CreateColumn(Column('bar', VARCHAR(50))), + "bar MYVARCHAR", + dialect=dialect + ) + self.assert_compile( + ddl.CreateColumn(Column('bar', INTEGER)), + "bar MYINTEGER bar", + dialect=dialect + ) + + +class TestKWArgPassThru(AssertsCompiledSQL, fixtures.TestBase): + __backend__ = True + + def test_user_defined(self): + """test that dialects pass the column through on DDL.""" + + class MyType(types.UserDefinedType): + def get_col_spec(self, **kw): + return "FOOB %s" % kw['type_expression'].name + + m = MetaData() + t = Table('t', m, Column('bar', MyType)) + self.assert_compile( + ddl.CreateColumn(t.c.bar), + "bar FOOB bar" + ) + class NumericRawSQLTest(fixtures.TestBase): |