summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--doc/build/changelog/changelog_10.rst16
-rw-r--r--lib/sqlalchemy/dialects/firebird/base.py20
-rw-r--r--lib/sqlalchemy/dialects/mssql/base.py82
-rw-r--r--lib/sqlalchemy/dialects/mysql/base.py76
-rw-r--r--lib/sqlalchemy/dialects/oracle/base.py62
-rw-r--r--lib/sqlalchemy/dialects/postgresql/base.py68
-rw-r--r--lib/sqlalchemy/dialects/sqlite/base.py11
-rw-r--r--lib/sqlalchemy/dialects/sybase/base.py27
-rw-r--r--lib/sqlalchemy/sql/compiler.py135
-rw-r--r--lib/sqlalchemy/sql/type_api.py24
-rw-r--r--lib/sqlalchemy/util/__init__.py2
-rw-r--r--lib/sqlalchemy/util/langhelpers.py27
-rw-r--r--test/sql/test_types.py63
13 files changed, 373 insertions, 240 deletions
diff --git a/doc/build/changelog/changelog_10.rst b/doc/build/changelog/changelog_10.rst
index 5d8bb7b68..089c9fafb 100644
--- a/doc/build/changelog/changelog_10.rst
+++ b/doc/build/changelog/changelog_10.rst
@@ -23,6 +23,22 @@
on compatibility concerns, see :doc:`/changelog/migration_10`.
.. change::
+ :tags: enhancement, sql
+ :tickets: 3074
+
+ Custom dialects that implement :class:`.GenericTypeCompiler` can
+ now be constructed such that the visit methods receive an indication
+ of the owning expression object, if any. Any visit method that
+ accepts keyword arguments (e.g. ``**kw``) will in most cases
+ receive a keyword argument ``type_expression``, referring to the
+ expression object that the type is contained within. For columns
+ in DDL, the dialect's compiler class may need to alter its
+ ``get_column_specification()`` method to support this as well.
+ The ``UserDefinedType.get_col_spec()`` method will also receive
+ ``type_expression`` if it provides ``**kw`` in its argument
+ signature.
+
+ .. change::
:tags: bug, sql
:tickets: 3288
diff --git a/lib/sqlalchemy/dialects/firebird/base.py b/lib/sqlalchemy/dialects/firebird/base.py
index 36229a105..74e8abfc2 100644
--- a/lib/sqlalchemy/dialects/firebird/base.py
+++ b/lib/sqlalchemy/dialects/firebird/base.py
@@ -180,16 +180,16 @@ ischema_names = {
# _FBDate, etc. as bind/result functionality is required)
class FBTypeCompiler(compiler.GenericTypeCompiler):
- def visit_boolean(self, type_):
- return self.visit_SMALLINT(type_)
+ def visit_boolean(self, type_, **kw):
+ return self.visit_SMALLINT(type_, **kw)
- def visit_datetime(self, type_):
- return self.visit_TIMESTAMP(type_)
+ def visit_datetime(self, type_, **kw):
+ return self.visit_TIMESTAMP(type_, **kw)
- def visit_TEXT(self, type_):
+ def visit_TEXT(self, type_, **kw):
return "BLOB SUB_TYPE 1"
- def visit_BLOB(self, type_):
+ def visit_BLOB(self, type_, **kw):
return "BLOB SUB_TYPE 0"
def _extend_string(self, type_, basic):
@@ -199,16 +199,16 @@ class FBTypeCompiler(compiler.GenericTypeCompiler):
else:
return '%s CHARACTER SET %s' % (basic, charset)
- def visit_CHAR(self, type_):
- basic = super(FBTypeCompiler, self).visit_CHAR(type_)
+ def visit_CHAR(self, type_, **kw):
+ basic = super(FBTypeCompiler, self).visit_CHAR(type_, **kw)
return self._extend_string(type_, basic)
- def visit_VARCHAR(self, type_):
+ def visit_VARCHAR(self, type_, **kw):
if not type_.length:
raise exc.CompileError(
"VARCHAR requires a length on dialect %s" %
self.dialect.name)
- basic = super(FBTypeCompiler, self).visit_VARCHAR(type_)
+ basic = super(FBTypeCompiler, self).visit_VARCHAR(type_, **kw)
return self._extend_string(type_, basic)
diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py
index 5d84975c0..92d7e4ab3 100644
--- a/lib/sqlalchemy/dialects/mssql/base.py
+++ b/lib/sqlalchemy/dialects/mssql/base.py
@@ -694,7 +694,6 @@ ischema_names = {
class MSTypeCompiler(compiler.GenericTypeCompiler):
-
def _extend(self, spec, type_, length=None):
"""Extend a string-type declaration with standard SQL
COLLATE annotations.
@@ -715,115 +714,115 @@ class MSTypeCompiler(compiler.GenericTypeCompiler):
return ' '.join([c for c in (spec, collation)
if c is not None])
- def visit_FLOAT(self, type_):
+ def visit_FLOAT(self, type_, **kw):
precision = getattr(type_, 'precision', None)
if precision is None:
return "FLOAT"
else:
return "FLOAT(%(precision)s)" % {'precision': precision}
- def visit_TINYINT(self, type_):
+ def visit_TINYINT(self, type_, **kw):
return "TINYINT"
- def visit_DATETIMEOFFSET(self, type_):
+ def visit_DATETIMEOFFSET(self, type_, **kw):
if type_.precision:
return "DATETIMEOFFSET(%s)" % type_.precision
else:
return "DATETIMEOFFSET"
- def visit_TIME(self, type_):
+ def visit_TIME(self, type_, **kw):
precision = getattr(type_, 'precision', None)
if precision:
return "TIME(%s)" % precision
else:
return "TIME"
- def visit_DATETIME2(self, type_):
+ def visit_DATETIME2(self, type_, **kw):
precision = getattr(type_, 'precision', None)
if precision:
return "DATETIME2(%s)" % precision
else:
return "DATETIME2"
- def visit_SMALLDATETIME(self, type_):
+ def visit_SMALLDATETIME(self, type_, **kw):
return "SMALLDATETIME"
- def visit_unicode(self, type_):
- return self.visit_NVARCHAR(type_)
+ def visit_unicode(self, type_, **kw):
+ return self.visit_NVARCHAR(type_, **kw)
- def visit_text(self, type_):
+ def visit_text(self, type_, **kw):
if self.dialect.deprecate_large_types:
- return self.visit_VARCHAR(type_)
+ return self.visit_VARCHAR(type_, **kw)
else:
- return self.visit_TEXT(type_)
+ return self.visit_TEXT(type_, **kw)
- def visit_unicode_text(self, type_):
+ def visit_unicode_text(self, type_, **kw):
if self.dialect.deprecate_large_types:
- return self.visit_NVARCHAR(type_)
+ return self.visit_NVARCHAR(type_, **kw)
else:
- return self.visit_NTEXT(type_)
+ return self.visit_NTEXT(type_, **kw)
- def visit_NTEXT(self, type_):
+ def visit_NTEXT(self, type_, **kw):
return self._extend("NTEXT", type_)
- def visit_TEXT(self, type_):
+ def visit_TEXT(self, type_, **kw):
return self._extend("TEXT", type_)
- def visit_VARCHAR(self, type_):
+ def visit_VARCHAR(self, type_, **kw):
return self._extend("VARCHAR", type_, length=type_.length or 'max')
- def visit_CHAR(self, type_):
+ def visit_CHAR(self, type_, **kw):
return self._extend("CHAR", type_)
- def visit_NCHAR(self, type_):
+ def visit_NCHAR(self, type_, **kw):
return self._extend("NCHAR", type_)
- def visit_NVARCHAR(self, type_):
+ def visit_NVARCHAR(self, type_, **kw):
return self._extend("NVARCHAR", type_, length=type_.length or 'max')
- def visit_date(self, type_):
+ def visit_date(self, type_, **kw):
if self.dialect.server_version_info < MS_2008_VERSION:
- return self.visit_DATETIME(type_)
+ return self.visit_DATETIME(type_, **kw)
else:
- return self.visit_DATE(type_)
+ return self.visit_DATE(type_, **kw)
- def visit_time(self, type_):
+ def visit_time(self, type_, **kw):
if self.dialect.server_version_info < MS_2008_VERSION:
- return self.visit_DATETIME(type_)
+ return self.visit_DATETIME(type_, **kw)
else:
- return self.visit_TIME(type_)
+ return self.visit_TIME(type_, **kw)
- def visit_large_binary(self, type_):
+ def visit_large_binary(self, type_, **kw):
if self.dialect.deprecate_large_types:
- return self.visit_VARBINARY(type_)
+ return self.visit_VARBINARY(type_, **kw)
else:
- return self.visit_IMAGE(type_)
+ return self.visit_IMAGE(type_, **kw)
- def visit_IMAGE(self, type_):
+ def visit_IMAGE(self, type_, **kw):
return "IMAGE"
- def visit_VARBINARY(self, type_):
+ def visit_VARBINARY(self, type_, **kw):
return self._extend(
"VARBINARY",
type_,
length=type_.length or 'max')
- def visit_boolean(self, type_):
+ def visit_boolean(self, type_, **kw):
return self.visit_BIT(type_)
- def visit_BIT(self, type_):
+ def visit_BIT(self, type_, **kw):
return "BIT"
- def visit_MONEY(self, type_):
+ def visit_MONEY(self, type_, **kw):
return "MONEY"
- def visit_SMALLMONEY(self, type_):
+ def visit_SMALLMONEY(self, type_, **kw):
return 'SMALLMONEY'
- def visit_UNIQUEIDENTIFIER(self, type_):
+ def visit_UNIQUEIDENTIFIER(self, type_, **kw):
return "UNIQUEIDENTIFIER"
- def visit_SQL_VARIANT(self, type_):
+ def visit_SQL_VARIANT(self, type_, **kw):
return 'SQL_VARIANT'
@@ -1240,8 +1239,11 @@ class MSSQLStrictCompiler(MSSQLCompiler):
class MSDDLCompiler(compiler.DDLCompiler):
def get_column_specification(self, column, **kwargs):
- colspec = (self.preparer.format_column(column) + " "
- + self.dialect.type_compiler.process(column.type))
+ colspec = (
+ self.preparer.format_column(column) + " "
+ + self.dialect.type_compiler.process(
+ column.type, type_expression=column)
+ )
if column.nullable is not None:
if not column.nullable or column.primary_key or \
diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py
index 9c3f23cb2..ca56a4d23 100644
--- a/lib/sqlalchemy/dialects/mysql/base.py
+++ b/lib/sqlalchemy/dialects/mysql/base.py
@@ -1859,9 +1859,11 @@ class MySQLDDLCompiler(compiler.DDLCompiler):
def get_column_specification(self, column, **kw):
"""Builds column DDL."""
- colspec = [self.preparer.format_column(column),
- self.dialect.type_compiler.process(column.type)
- ]
+ colspec = [
+ self.preparer.format_column(column),
+ self.dialect.type_compiler.process(
+ column.type, type_expression=column)
+ ]
default = self.get_column_default_string(column)
if default is not None:
@@ -2059,7 +2061,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
def _mysql_type(self, type_):
return isinstance(type_, (_StringType, _NumericType))
- def visit_NUMERIC(self, type_):
+ def visit_NUMERIC(self, type_, **kw):
if type_.precision is None:
return self._extend_numeric(type_, "NUMERIC")
elif type_.scale is None:
@@ -2072,7 +2074,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
{'precision': type_.precision,
'scale': type_.scale})
- def visit_DECIMAL(self, type_):
+ def visit_DECIMAL(self, type_, **kw):
if type_.precision is None:
return self._extend_numeric(type_, "DECIMAL")
elif type_.scale is None:
@@ -2085,7 +2087,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
{'precision': type_.precision,
'scale': type_.scale})
- def visit_DOUBLE(self, type_):
+ def visit_DOUBLE(self, type_, **kw):
if type_.precision is not None and type_.scale is not None:
return self._extend_numeric(type_,
"DOUBLE(%(precision)s, %(scale)s)" %
@@ -2094,7 +2096,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
else:
return self._extend_numeric(type_, 'DOUBLE')
- def visit_REAL(self, type_):
+ def visit_REAL(self, type_, **kw):
if type_.precision is not None and type_.scale is not None:
return self._extend_numeric(type_,
"REAL(%(precision)s, %(scale)s)" %
@@ -2103,7 +2105,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
else:
return self._extend_numeric(type_, 'REAL')
- def visit_FLOAT(self, type_):
+ def visit_FLOAT(self, type_, **kw):
if self._mysql_type(type_) and \
type_.scale is not None and \
type_.precision is not None:
@@ -2115,7 +2117,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
else:
return self._extend_numeric(type_, "FLOAT")
- def visit_INTEGER(self, type_):
+ def visit_INTEGER(self, type_, **kw):
if self._mysql_type(type_) and type_.display_width is not None:
return self._extend_numeric(
type_, "INTEGER(%(display_width)s)" %
@@ -2123,7 +2125,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
else:
return self._extend_numeric(type_, "INTEGER")
- def visit_BIGINT(self, type_):
+ def visit_BIGINT(self, type_, **kw):
if self._mysql_type(type_) and type_.display_width is not None:
return self._extend_numeric(
type_, "BIGINT(%(display_width)s)" %
@@ -2131,7 +2133,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
else:
return self._extend_numeric(type_, "BIGINT")
- def visit_MEDIUMINT(self, type_):
+ def visit_MEDIUMINT(self, type_, **kw):
if self._mysql_type(type_) and type_.display_width is not None:
return self._extend_numeric(
type_, "MEDIUMINT(%(display_width)s)" %
@@ -2139,14 +2141,14 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
else:
return self._extend_numeric(type_, "MEDIUMINT")
- def visit_TINYINT(self, type_):
+ def visit_TINYINT(self, type_, **kw):
if self._mysql_type(type_) and type_.display_width is not None:
return self._extend_numeric(type_,
"TINYINT(%s)" % type_.display_width)
else:
return self._extend_numeric(type_, "TINYINT")
- def visit_SMALLINT(self, type_):
+ def visit_SMALLINT(self, type_, **kw):
if self._mysql_type(type_) and type_.display_width is not None:
return self._extend_numeric(type_,
"SMALLINT(%(display_width)s)" %
@@ -2155,55 +2157,55 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
else:
return self._extend_numeric(type_, "SMALLINT")
- def visit_BIT(self, type_):
+ def visit_BIT(self, type_, **kw):
if type_.length is not None:
return "BIT(%s)" % type_.length
else:
return "BIT"
- def visit_DATETIME(self, type_):
+ def visit_DATETIME(self, type_, **kw):
if getattr(type_, 'fsp', None):
return "DATETIME(%d)" % type_.fsp
else:
return "DATETIME"
- def visit_DATE(self, type_):
+ def visit_DATE(self, type_, **kw):
return "DATE"
- def visit_TIME(self, type_):
+ def visit_TIME(self, type_, **kw):
if getattr(type_, 'fsp', None):
return "TIME(%d)" % type_.fsp
else:
return "TIME"
- def visit_TIMESTAMP(self, type_):
+ def visit_TIMESTAMP(self, type_, **kw):
if getattr(type_, 'fsp', None):
return "TIMESTAMP(%d)" % type_.fsp
else:
return "TIMESTAMP"
- def visit_YEAR(self, type_):
+ def visit_YEAR(self, type_, **kw):
if type_.display_width is None:
return "YEAR"
else:
return "YEAR(%s)" % type_.display_width
- def visit_TEXT(self, type_):
+ def visit_TEXT(self, type_, **kw):
if type_.length:
return self._extend_string(type_, {}, "TEXT(%d)" % type_.length)
else:
return self._extend_string(type_, {}, "TEXT")
- def visit_TINYTEXT(self, type_):
+ def visit_TINYTEXT(self, type_, **kw):
return self._extend_string(type_, {}, "TINYTEXT")
- def visit_MEDIUMTEXT(self, type_):
+ def visit_MEDIUMTEXT(self, type_, **kw):
return self._extend_string(type_, {}, "MEDIUMTEXT")
- def visit_LONGTEXT(self, type_):
+ def visit_LONGTEXT(self, type_, **kw):
return self._extend_string(type_, {}, "LONGTEXT")
- def visit_VARCHAR(self, type_):
+ def visit_VARCHAR(self, type_, **kw):
if type_.length:
return self._extend_string(
type_, {}, "VARCHAR(%d)" % type_.length)
@@ -2212,14 +2214,14 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
"VARCHAR requires a length on dialect %s" %
self.dialect.name)
- def visit_CHAR(self, type_):
+ def visit_CHAR(self, type_, **kw):
if type_.length:
return self._extend_string(type_, {}, "CHAR(%(length)s)" %
{'length': type_.length})
else:
return self._extend_string(type_, {}, "CHAR")
- def visit_NVARCHAR(self, type_):
+ def visit_NVARCHAR(self, type_, **kw):
# We'll actually generate the equiv. "NATIONAL VARCHAR" instead
# of "NVARCHAR".
if type_.length:
@@ -2231,7 +2233,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
"NVARCHAR requires a length on dialect %s" %
self.dialect.name)
- def visit_NCHAR(self, type_):
+ def visit_NCHAR(self, type_, **kw):
# We'll actually generate the equiv.
# "NATIONAL CHAR" instead of "NCHAR".
if type_.length:
@@ -2241,31 +2243,31 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
else:
return self._extend_string(type_, {'national': True}, "CHAR")
- def visit_VARBINARY(self, type_):
+ def visit_VARBINARY(self, type_, **kw):
return "VARBINARY(%d)" % type_.length
- def visit_large_binary(self, type_):
+ def visit_large_binary(self, type_, **kw):
return self.visit_BLOB(type_)
- def visit_enum(self, type_):
+ def visit_enum(self, type_, **kw):
if not type_.native_enum:
return super(MySQLTypeCompiler, self).visit_enum(type_)
else:
return self._visit_enumerated_values("ENUM", type_, type_.enums)
- def visit_BLOB(self, type_):
+ def visit_BLOB(self, type_, **kw):
if type_.length:
return "BLOB(%d)" % type_.length
else:
return "BLOB"
- def visit_TINYBLOB(self, type_):
+ def visit_TINYBLOB(self, type_, **kw):
return "TINYBLOB"
- def visit_MEDIUMBLOB(self, type_):
+ def visit_MEDIUMBLOB(self, type_, **kw):
return "MEDIUMBLOB"
- def visit_LONGBLOB(self, type_):
+ def visit_LONGBLOB(self, type_, **kw):
return "LONGBLOB"
def _visit_enumerated_values(self, name, type_, enumerated_values):
@@ -2276,15 +2278,15 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
name, ",".join(quoted_enums))
)
- def visit_ENUM(self, type_):
+ def visit_ENUM(self, type_, **kw):
return self._visit_enumerated_values("ENUM", type_,
type_._enumerated_values)
- def visit_SET(self, type_):
+ def visit_SET(self, type_, **kw):
return self._visit_enumerated_values("SET", type_,
type_._enumerated_values)
- def visit_BOOLEAN(self, type):
+ def visit_BOOLEAN(self, type, **kw):
return "BOOL"
diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py
index 9f375da94..b482c9069 100644
--- a/lib/sqlalchemy/dialects/oracle/base.py
+++ b/lib/sqlalchemy/dialects/oracle/base.py
@@ -457,19 +457,19 @@ class OracleTypeCompiler(compiler.GenericTypeCompiler):
# Oracle does not allow milliseconds in DATE
# Oracle does not support TIME columns
- def visit_datetime(self, type_):
- return self.visit_DATE(type_)
+ def visit_datetime(self, type_, **kw):
+ return self.visit_DATE(type_, **kw)
- def visit_float(self, type_):
- return self.visit_FLOAT(type_)
+ def visit_float(self, type_, **kw):
+ return self.visit_FLOAT(type_, **kw)
- def visit_unicode(self, type_):
+ def visit_unicode(self, type_, **kw):
if self.dialect._supports_nchar:
- return self.visit_NVARCHAR2(type_)
+ return self.visit_NVARCHAR2(type_, **kw)
else:
- return self.visit_VARCHAR2(type_)
+ return self.visit_VARCHAR2(type_, **kw)
- def visit_INTERVAL(self, type_):
+ def visit_INTERVAL(self, type_, **kw):
return "INTERVAL DAY%s TO SECOND%s" % (
type_.day_precision is not None and
"(%d)" % type_.day_precision or
@@ -479,22 +479,22 @@ class OracleTypeCompiler(compiler.GenericTypeCompiler):
"",
)
- def visit_LONG(self, type_):
+ def visit_LONG(self, type_, **kw):
return "LONG"
- def visit_TIMESTAMP(self, type_):
+ def visit_TIMESTAMP(self, type_, **kw):
if type_.timezone:
return "TIMESTAMP WITH TIME ZONE"
else:
return "TIMESTAMP"
- def visit_DOUBLE_PRECISION(self, type_):
- return self._generate_numeric(type_, "DOUBLE PRECISION")
+ def visit_DOUBLE_PRECISION(self, type_, **kw):
+ return self._generate_numeric(type_, "DOUBLE PRECISION", **kw)
def visit_NUMBER(self, type_, **kw):
return self._generate_numeric(type_, "NUMBER", **kw)
- def _generate_numeric(self, type_, name, precision=None, scale=None):
+ def _generate_numeric(self, type_, name, precision=None, scale=None, **kw):
if precision is None:
precision = type_.precision
@@ -510,17 +510,17 @@ class OracleTypeCompiler(compiler.GenericTypeCompiler):
n = "%(name)s(%(precision)s, %(scale)s)"
return n % {'name': name, 'precision': precision, 'scale': scale}
- def visit_string(self, type_):
- return self.visit_VARCHAR2(type_)
+ def visit_string(self, type_, **kw):
+ return self.visit_VARCHAR2(type_, **kw)
- def visit_VARCHAR2(self, type_):
+ def visit_VARCHAR2(self, type_, **kw):
return self._visit_varchar(type_, '', '2')
- def visit_NVARCHAR2(self, type_):
+ def visit_NVARCHAR2(self, type_, **kw):
return self._visit_varchar(type_, 'N', '2')
visit_NVARCHAR = visit_NVARCHAR2
- def visit_VARCHAR(self, type_):
+ def visit_VARCHAR(self, type_, **kw):
return self._visit_varchar(type_, '', '')
def _visit_varchar(self, type_, n, num):
@@ -533,31 +533,31 @@ class OracleTypeCompiler(compiler.GenericTypeCompiler):
varchar = "%(n)sVARCHAR%(two)s(%(length)s)"
return varchar % {'length': type_.length, 'two': num, 'n': n}
- def visit_text(self, type_):
- return self.visit_CLOB(type_)
+ def visit_text(self, type_, **kw):
+ return self.visit_CLOB(type_, **kw)
- def visit_unicode_text(self, type_):
+ def visit_unicode_text(self, type_, **kw):
if self.dialect._supports_nchar:
- return self.visit_NCLOB(type_)
+ return self.visit_NCLOB(type_, **kw)
else:
- return self.visit_CLOB(type_)
+ return self.visit_CLOB(type_, **kw)
- def visit_large_binary(self, type_):
- return self.visit_BLOB(type_)
+ def visit_large_binary(self, type_, **kw):
+ return self.visit_BLOB(type_, **kw)
- def visit_big_integer(self, type_):
- return self.visit_NUMBER(type_, precision=19)
+ def visit_big_integer(self, type_, **kw):
+ return self.visit_NUMBER(type_, precision=19, **kw)
- def visit_boolean(self, type_):
- return self.visit_SMALLINT(type_)
+ def visit_boolean(self, type_, **kw):
+ return self.visit_SMALLINT(type_, **kw)
- def visit_RAW(self, type_):
+ def visit_RAW(self, type_, **kw):
if type_.length:
return "RAW(%(length)s)" % {'length': type_.length}
else:
return "RAW"
- def visit_ROWID(self, type_):
+ def visit_ROWID(self, type_, **kw):
return "ROWID"
diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py
index 0817fe837..89bea100e 100644
--- a/lib/sqlalchemy/dialects/postgresql/base.py
+++ b/lib/sqlalchemy/dialects/postgresql/base.py
@@ -1425,7 +1425,8 @@ class PGDDLCompiler(compiler.DDLCompiler):
else:
colspec += " SERIAL"
else:
- colspec += " " + self.dialect.type_compiler.process(column.type)
+ colspec += " " + self.dialect.type_compiler.process(column.type,
+ type_expression=column)
default = self.get_column_default_string(column)
if default is not None:
colspec += " DEFAULT " + default
@@ -1545,94 +1546,93 @@ class PGDDLCompiler(compiler.DDLCompiler):
class PGTypeCompiler(compiler.GenericTypeCompiler):
-
- def visit_TSVECTOR(self, type):
+ def visit_TSVECTOR(self, type, **kw):
return "TSVECTOR"
- def visit_INET(self, type_):
+ def visit_INET(self, type_, **kw):
return "INET"
- def visit_CIDR(self, type_):
+ def visit_CIDR(self, type_, **kw):
return "CIDR"
- def visit_MACADDR(self, type_):
+ def visit_MACADDR(self, type_, **kw):
return "MACADDR"
- def visit_OID(self, type_):
+ def visit_OID(self, type_, **kw):
return "OID"
- def visit_FLOAT(self, type_):
+ def visit_FLOAT(self, type_, **kw):
if not type_.precision:
return "FLOAT"
else:
return "FLOAT(%(precision)s)" % {'precision': type_.precision}
- def visit_DOUBLE_PRECISION(self, type_):
+ def visit_DOUBLE_PRECISION(self, type_, **kw):
return "DOUBLE PRECISION"
- def visit_BIGINT(self, type_):
+ def visit_BIGINT(self, type_, **kw):
return "BIGINT"
- def visit_HSTORE(self, type_):
+ def visit_HSTORE(self, type_, **kw):
return "HSTORE"
- def visit_JSON(self, type_):
+ def visit_JSON(self, type_, **kw):
return "JSON"
- def visit_JSONB(self, type_):
+ def visit_JSONB(self, type_, **kw):
return "JSONB"
- def visit_INT4RANGE(self, type_):
+ def visit_INT4RANGE(self, type_, **kw):
return "INT4RANGE"
- def visit_INT8RANGE(self, type_):
+ def visit_INT8RANGE(self, type_, **kw):
return "INT8RANGE"
- def visit_NUMRANGE(self, type_):
+ def visit_NUMRANGE(self, type_, **kw):
return "NUMRANGE"
- def visit_DATERANGE(self, type_):
+ def visit_DATERANGE(self, type_, **kw):
return "DATERANGE"
- def visit_TSRANGE(self, type_):
+ def visit_TSRANGE(self, type_, **kw):
return "TSRANGE"
- def visit_TSTZRANGE(self, type_):
+ def visit_TSTZRANGE(self, type_, **kw):
return "TSTZRANGE"
- def visit_datetime(self, type_):
- return self.visit_TIMESTAMP(type_)
+ def visit_datetime(self, type_, **kw):
+ return self.visit_TIMESTAMP(type_, **kw)
- def visit_enum(self, type_):
+ def visit_enum(self, type_, **kw):
if not type_.native_enum or not self.dialect.supports_native_enum:
- return super(PGTypeCompiler, self).visit_enum(type_)
+ return super(PGTypeCompiler, self).visit_enum(type_, **kw)
else:
- return self.visit_ENUM(type_)
+ return self.visit_ENUM(type_, **kw)
- def visit_ENUM(self, type_):
+ def visit_ENUM(self, type_, **kw):
return self.dialect.identifier_preparer.format_type(type_)
- def visit_TIMESTAMP(self, type_):
+ def visit_TIMESTAMP(self, type_, **kw):
return "TIMESTAMP%s %s" % (
getattr(type_, 'precision', None) and "(%d)" %
type_.precision or "",
(type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE"
)
- def visit_TIME(self, type_):
+ def visit_TIME(self, type_, **kw):
return "TIME%s %s" % (
getattr(type_, 'precision', None) and "(%d)" %
type_.precision or "",
(type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE"
)
- def visit_INTERVAL(self, type_):
+ def visit_INTERVAL(self, type_, **kw):
if type_.precision is not None:
return "INTERVAL(%d)" % type_.precision
else:
return "INTERVAL"
- def visit_BIT(self, type_):
+ def visit_BIT(self, type_, **kw):
if type_.varying:
compiled = "BIT VARYING"
if type_.length is not None:
@@ -1641,16 +1641,16 @@ class PGTypeCompiler(compiler.GenericTypeCompiler):
compiled = "BIT(%d)" % type_.length
return compiled
- def visit_UUID(self, type_):
+ def visit_UUID(self, type_, **kw):
return "UUID"
- def visit_large_binary(self, type_):
- return self.visit_BYTEA(type_)
+ def visit_large_binary(self, type_, **kw):
+ return self.visit_BYTEA(type_, **kw)
- def visit_BYTEA(self, type_):
+ def visit_BYTEA(self, type_, **kw):
return "BYTEA"
- def visit_ARRAY(self, type_):
+ def visit_ARRAY(self, type_, **kw):
return self.process(type_.item_type) + ('[]' * (type_.dimensions
if type_.dimensions
is not None else 1))
diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py
index 3d7b0788b..f74421967 100644
--- a/lib/sqlalchemy/dialects/sqlite/base.py
+++ b/lib/sqlalchemy/dialects/sqlite/base.py
@@ -660,7 +660,8 @@ class SQLiteCompiler(compiler.SQLCompiler):
class SQLiteDDLCompiler(compiler.DDLCompiler):
def get_column_specification(self, column, **kwargs):
- coltype = self.dialect.type_compiler.process(column.type)
+ coltype = self.dialect.type_compiler.process(
+ column.type, type_expression=column)
colspec = self.preparer.format_column(column) + " " + coltype
default = self.get_column_default_string(column)
if default is not None:
@@ -716,24 +717,24 @@ class SQLiteDDLCompiler(compiler.DDLCompiler):
class SQLiteTypeCompiler(compiler.GenericTypeCompiler):
- def visit_large_binary(self, type_):
+ def visit_large_binary(self, type_, **kw):
return self.visit_BLOB(type_)
- def visit_DATETIME(self, type_):
+ def visit_DATETIME(self, type_, **kw):
if not isinstance(type_, _DateTimeMixin) or \
type_.format_is_text_affinity:
return super(SQLiteTypeCompiler, self).visit_DATETIME(type_)
else:
return "DATETIME_CHAR"
- def visit_DATE(self, type_):
+ def visit_DATE(self, type_, **kw):
if not isinstance(type_, _DateTimeMixin) or \
type_.format_is_text_affinity:
return super(SQLiteTypeCompiler, self).visit_DATE(type_)
else:
return "DATE_CHAR"
- def visit_TIME(self, type_):
+ def visit_TIME(self, type_, **kw):
if not isinstance(type_, _DateTimeMixin) or \
type_.format_is_text_affinity:
return super(SQLiteTypeCompiler, self).visit_TIME(type_)
diff --git a/lib/sqlalchemy/dialects/sybase/base.py b/lib/sqlalchemy/dialects/sybase/base.py
index f65a76a27..369420358 100644
--- a/lib/sqlalchemy/dialects/sybase/base.py
+++ b/lib/sqlalchemy/dialects/sybase/base.py
@@ -146,40 +146,40 @@ class IMAGE(sqltypes.LargeBinary):
class SybaseTypeCompiler(compiler.GenericTypeCompiler):
- def visit_large_binary(self, type_):
+ def visit_large_binary(self, type_, **kw):
return self.visit_IMAGE(type_)
- def visit_boolean(self, type_):
+ def visit_boolean(self, type_, **kw):
return self.visit_BIT(type_)
- def visit_unicode(self, type_):
+ def visit_unicode(self, type_, **kw):
return self.visit_NVARCHAR(type_)
- def visit_UNICHAR(self, type_):
+ def visit_UNICHAR(self, type_, **kw):
return "UNICHAR(%d)" % type_.length
- def visit_UNIVARCHAR(self, type_):
+ def visit_UNIVARCHAR(self, type_, **kw):
return "UNIVARCHAR(%d)" % type_.length
- def visit_UNITEXT(self, type_):
+ def visit_UNITEXT(self, type_, **kw):
return "UNITEXT"
- def visit_TINYINT(self, type_):
+ def visit_TINYINT(self, type_, **kw):
return "TINYINT"
- def visit_IMAGE(self, type_):
+ def visit_IMAGE(self, type_, **kw):
return "IMAGE"
- def visit_BIT(self, type_):
+ def visit_BIT(self, type_, **kw):
return "BIT"
- def visit_MONEY(self, type_):
+ def visit_MONEY(self, type_, **kw):
return "MONEY"
- def visit_SMALLMONEY(self, type_):
+ def visit_SMALLMONEY(self, type_, **kw):
return "SMALLMONEY"
- def visit_UNIQUEIDENTIFIER(self, type_):
+ def visit_UNIQUEIDENTIFIER(self, type_, **kw):
return "UNIQUEIDENTIFIER"
ischema_names = {
@@ -377,7 +377,8 @@ class SybaseSQLCompiler(compiler.SQLCompiler):
class SybaseDDLCompiler(compiler.DDLCompiler):
def get_column_specification(self, column, **kwargs):
colspec = self.preparer.format_column(column) + " " + \
- self.dialect.type_compiler.process(column.type)
+ self.dialect.type_compiler.process(
+ column.type, type_expression=column)
if column.table is None:
raise exc.CompileError(
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index ca14c9371..da62b1434 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -248,15 +248,16 @@ class Compiled(object):
return self.execute(*multiparams, **params).scalar()
-class TypeCompiler(object):
-
+class TypeCompiler(util.with_metaclass(util.EnsureKWArgType, object)):
"""Produces DDL specification for TypeEngine objects."""
+ ensure_kwarg = 'visit_\w+'
+
def __init__(self, dialect):
self.dialect = dialect
- def process(self, type_):
- return type_._compiler_dispatch(self)
+ def process(self, type_, **kw):
+ return type_._compiler_dispatch(self, **kw)
class _CompileLabel(visitors.Visitable):
@@ -638,8 +639,9 @@ class SQLCompiler(Compiled):
def visit_index(self, index, **kwargs):
return index.name
- def visit_typeclause(self, typeclause, **kwargs):
- return self.dialect.type_compiler.process(typeclause.type)
+ def visit_typeclause(self, typeclause, **kw):
+ kw['type_expression'] = typeclause
+ return self.dialect.type_compiler.process(typeclause.type, **kw)
def post_process_text(self, text):
return text
@@ -2259,7 +2261,8 @@ class DDLCompiler(Compiled):
def get_column_specification(self, column, **kwargs):
colspec = self.preparer.format_column(column) + " " + \
- self.dialect.type_compiler.process(column.type)
+ self.dialect.type_compiler.process(
+ column.type, type_expression=column)
default = self.get_column_default_string(column)
if default is not None:
colspec += " DEFAULT " + default
@@ -2383,13 +2386,13 @@ class DDLCompiler(Compiled):
class GenericTypeCompiler(TypeCompiler):
- def visit_FLOAT(self, type_):
+ def visit_FLOAT(self, type_, **kw):
return "FLOAT"
- def visit_REAL(self, type_):
+ def visit_REAL(self, type_, **kw):
return "REAL"
- def visit_NUMERIC(self, type_):
+ def visit_NUMERIC(self, type_, **kw):
if type_.precision is None:
return "NUMERIC"
elif type_.scale is None:
@@ -2400,7 +2403,7 @@ class GenericTypeCompiler(TypeCompiler):
{'precision': type_.precision,
'scale': type_.scale}
- def visit_DECIMAL(self, type_):
+ def visit_DECIMAL(self, type_, **kw):
if type_.precision is None:
return "DECIMAL"
elif type_.scale is None:
@@ -2411,31 +2414,31 @@ class GenericTypeCompiler(TypeCompiler):
{'precision': type_.precision,
'scale': type_.scale}
- def visit_INTEGER(self, type_):
+ def visit_INTEGER(self, type_, **kw):
return "INTEGER"
- def visit_SMALLINT(self, type_):
+ def visit_SMALLINT(self, type_, **kw):
return "SMALLINT"
- def visit_BIGINT(self, type_):
+ def visit_BIGINT(self, type_, **kw):
return "BIGINT"
- def visit_TIMESTAMP(self, type_):
+ def visit_TIMESTAMP(self, type_, **kw):
return 'TIMESTAMP'
- def visit_DATETIME(self, type_):
+ def visit_DATETIME(self, type_, **kw):
return "DATETIME"
- def visit_DATE(self, type_):
+ def visit_DATE(self, type_, **kw):
return "DATE"
- def visit_TIME(self, type_):
+ def visit_TIME(self, type_, **kw):
return "TIME"
- def visit_CLOB(self, type_):
+ def visit_CLOB(self, type_, **kw):
return "CLOB"
- def visit_NCLOB(self, type_):
+ def visit_NCLOB(self, type_, **kw):
return "NCLOB"
def _render_string_type(self, type_, name):
@@ -2447,91 +2450,91 @@ class GenericTypeCompiler(TypeCompiler):
text += ' COLLATE "%s"' % type_.collation
return text
- def visit_CHAR(self, type_):
+ def visit_CHAR(self, type_, **kw):
return self._render_string_type(type_, "CHAR")
- def visit_NCHAR(self, type_):
+ def visit_NCHAR(self, type_, **kw):
return self._render_string_type(type_, "NCHAR")
- def visit_VARCHAR(self, type_):
+ def visit_VARCHAR(self, type_, **kw):
return self._render_string_type(type_, "VARCHAR")
- def visit_NVARCHAR(self, type_):
+ def visit_NVARCHAR(self, type_, **kw):
return self._render_string_type(type_, "NVARCHAR")
- def visit_TEXT(self, type_):
+ def visit_TEXT(self, type_, **kw):
return self._render_string_type(type_, "TEXT")
- def visit_BLOB(self, type_):
+ def visit_BLOB(self, type_, **kw):
return "BLOB"
- def visit_BINARY(self, type_):
+ def visit_BINARY(self, type_, **kw):
return "BINARY" + (type_.length and "(%d)" % type_.length or "")
- def visit_VARBINARY(self, type_):
+ def visit_VARBINARY(self, type_, **kw):
return "VARBINARY" + (type_.length and "(%d)" % type_.length or "")
- def visit_BOOLEAN(self, type_):
+ def visit_BOOLEAN(self, type_, **kw):
return "BOOLEAN"
- def visit_large_binary(self, type_):
- return self.visit_BLOB(type_)
+ def visit_large_binary(self, type_, **kw):
+ return self.visit_BLOB(type_, **kw)
- def visit_boolean(self, type_):
- return self.visit_BOOLEAN(type_)
+ def visit_boolean(self, type_, **kw):
+ return self.visit_BOOLEAN(type_, **kw)
- def visit_time(self, type_):
- return self.visit_TIME(type_)
+ def visit_time(self, type_, **kw):
+ return self.visit_TIME(type_, **kw)
- def visit_datetime(self, type_):
- return self.visit_DATETIME(type_)
+ def visit_datetime(self, type_, **kw):
+ return self.visit_DATETIME(type_, **kw)
- def visit_date(self, type_):
- return self.visit_DATE(type_)
+ def visit_date(self, type_, **kw):
+ return self.visit_DATE(type_, **kw)
- def visit_big_integer(self, type_):
- return self.visit_BIGINT(type_)
+ def visit_big_integer(self, type_, **kw):
+ return self.visit_BIGINT(type_, **kw)
- def visit_small_integer(self, type_):
- return self.visit_SMALLINT(type_)
+ def visit_small_integer(self, type_, **kw):
+ return self.visit_SMALLINT(type_, **kw)
- def visit_integer(self, type_):
- return self.visit_INTEGER(type_)
+ def visit_integer(self, type_, **kw):
+ return self.visit_INTEGER(type_, **kw)
- def visit_real(self, type_):
- return self.visit_REAL(type_)
+ def visit_real(self, type_, **kw):
+ return self.visit_REAL(type_, **kw)
- def visit_float(self, type_):
- return self.visit_FLOAT(type_)
+ def visit_float(self, type_, **kw):
+ return self.visit_FLOAT(type_, **kw)
- def visit_numeric(self, type_):
- return self.visit_NUMERIC(type_)
+ def visit_numeric(self, type_, **kw):
+ return self.visit_NUMERIC(type_, **kw)
- def visit_string(self, type_):
- return self.visit_VARCHAR(type_)
+ def visit_string(self, type_, **kw):
+ return self.visit_VARCHAR(type_, **kw)
- def visit_unicode(self, type_):
- return self.visit_VARCHAR(type_)
+ def visit_unicode(self, type_, **kw):
+ return self.visit_VARCHAR(type_, **kw)
- def visit_text(self, type_):
- return self.visit_TEXT(type_)
+ def visit_text(self, type_, **kw):
+ return self.visit_TEXT(type_, **kw)
- def visit_unicode_text(self, type_):
- return self.visit_TEXT(type_)
+ def visit_unicode_text(self, type_, **kw):
+ return self.visit_TEXT(type_, **kw)
- def visit_enum(self, type_):
- return self.visit_VARCHAR(type_)
+ def visit_enum(self, type_, **kw):
+ return self.visit_VARCHAR(type_, **kw)
- def visit_null(self, type_):
+ def visit_null(self, type_, **kw):
raise exc.CompileError("Can't generate DDL for %r; "
"did you forget to specify a "
"type on this Column?" % type_)
- def visit_type_decorator(self, type_):
- return self.process(type_.type_engine(self.dialect))
+ def visit_type_decorator(self, type_, **kw):
+ return self.process(type_.type_engine(self.dialect), **kw)
- def visit_user_defined(self, type_):
- return type_.get_col_spec()
+ def visit_user_defined(self, type_, **kw):
+ return type_.get_col_spec(**kw)
class IdentifierPreparer(object):
diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py
index bff497800..19398ae96 100644
--- a/lib/sqlalchemy/sql/type_api.py
+++ b/lib/sqlalchemy/sql/type_api.py
@@ -12,7 +12,7 @@
from .. import exc, util
from . import operators
-from .visitors import Visitable
+from .visitors import Visitable, VisitableType
# these are back-assigned by sqltypes.
BOOLEANTYPE = None
@@ -460,7 +460,11 @@ class TypeEngine(Visitable):
return util.generic_repr(self)
-class UserDefinedType(TypeEngine):
+class VisitableCheckKWArg(util.EnsureKWArgType, VisitableType):
+ pass
+
+
+class UserDefinedType(util.with_metaclass(VisitableCheckKWArg, TypeEngine)):
"""Base for user defined types.
This should be the base of new types. Note that
@@ -473,7 +477,7 @@ class UserDefinedType(TypeEngine):
def __init__(self, precision = 8):
self.precision = precision
- def get_col_spec(self):
+ def get_col_spec(self, **kw):
return "MYTYPE(%s)" % self.precision
def bind_processor(self, dialect):
@@ -493,9 +497,23 @@ class UserDefinedType(TypeEngine):
Column('data', MyType(16))
)
+ The ``get_col_spec()`` method will in most cases receive a keyword
+ argument ``type_expression`` which refers to the owning expression
+ of the type as being compiled, such as a :class:`.Column` or
+ :func:`.cast` construct. This keyword is only sent if the method
+ accepts keyword arguments (e.g. ``**kw``) in its argument signature;
+ introspection is used to check for this in order to support legacy
+ forms of this function.
+
+ .. versionadded:: 1.0.0 the owning expression is passed to
+ the ``get_col_spec()`` method via the keyword argument
+ ``type_expression``, if it receives ``**kw`` in its signature.
+
"""
__visit_name__ = "user_defined"
+ ensure_kwarg = 'get_col_spec'
+
class Comparator(TypeEngine.Comparator):
__slots__ = ()
diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py
index c23b0196f..ceee18d86 100644
--- a/lib/sqlalchemy/util/__init__.py
+++ b/lib/sqlalchemy/util/__init__.py
@@ -36,7 +36,7 @@ from .langhelpers import iterate_attributes, class_hierarchy, \
generic_repr, counter, PluginLoader, hybridproperty, hybridmethod, \
safe_reraise,\
get_callable_argspec, only_once, attrsetter, ellipses_string, \
- warn_limited, map_bits, MemoizedSlots
+ warn_limited, map_bits, MemoizedSlots, EnsureKWArgType
from .deprecations import warn_deprecated, warn_pending_deprecation, \
deprecated, pending_deprecation, inject_docstring_text
diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py
index 22b6ad4ca..5a938501a 100644
--- a/lib/sqlalchemy/util/langhelpers.py
+++ b/lib/sqlalchemy/util/langhelpers.py
@@ -1348,6 +1348,7 @@ def chop_traceback(tb, exclude_prefix=_UNITTEST_RE, exclude_suffix=_SQLA_RE):
NoneType = type(None)
+
def attrsetter(attrname):
code = \
"def set(obj, value):"\
@@ -1355,3 +1356,29 @@ def attrsetter(attrname):
env = locals().copy()
exec(code, env)
return env['set']
+
+
+class EnsureKWArgType(type):
+ """Apply translation of functions to accept **kw arguments if they
+ don't already.
+
+ """
+ def __init__(cls, clsname, bases, clsdict):
+ fn_reg = cls.ensure_kwarg
+ if fn_reg:
+ for key in clsdict:
+ m = re.match(fn_reg, key)
+ if m:
+ fn = clsdict[key]
+ spec = inspect.getargspec(fn)
+ if not spec.keywords:
+ clsdict[key] = wrapped = cls._wrap_w_kw(fn)
+ setattr(cls, key, wrapped)
+ super(EnsureKWArgType, cls).__init__(clsname, bases, clsdict)
+
+ def _wrap_w_kw(self, fn):
+
+ def wrap(*arg, **kw):
+ return fn(*arg)
+ return update_wrapper(wrap, fn)
+
diff --git a/test/sql/test_types.py b/test/sql/test_types.py
index 6ffd88d78..38b3ced13 100644
--- a/test/sql/test_types.py
+++ b/test/sql/test_types.py
@@ -10,6 +10,8 @@ from sqlalchemy import (
type_coerce, VARCHAR, Time, DateTime, BigInteger, SmallInteger, BOOLEAN,
BLOB, NCHAR, NVARCHAR, CLOB, TIME, DATE, DATETIME, TIMESTAMP, SMALLINT,
INTEGER, DECIMAL, NUMERIC, FLOAT, REAL)
+from sqlalchemy.sql import ddl
+
from sqlalchemy import exc, types, util, dialects
for name in dialects.__all__:
__import__("sqlalchemy.dialects.%s" % name)
@@ -309,6 +311,24 @@ class UserDefinedTest(fixtures.TablesTest, AssertsCompiledSQL):
literal_binds=True
)
+ def test_kw_colspec(self):
+ class MyType(types.UserDefinedType):
+ def get_col_spec(self, **kw):
+ return "FOOB %s" % kw['type_expression'].name
+
+ class MyOtherType(types.UserDefinedType):
+ def get_col_spec(self):
+ return "BAR"
+
+ self.assert_compile(
+ ddl.CreateColumn(Column('bar', MyType)),
+ "bar FOOB bar"
+ )
+ self.assert_compile(
+ ddl.CreateColumn(Column('bar', MyOtherType)),
+ "bar BAR"
+ )
+
def test_typedecorator_literal_render_fallback_bound(self):
# fall back to process_bind_param for literal
# value rendering.
@@ -1642,6 +1662,49 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
def test_decimal_scale(self):
self.assert_compile(types.DECIMAL(2, 4), 'DECIMAL(2, 4)')
+ def test_kwarg_legacy_typecompiler(self):
+ from sqlalchemy.sql import compiler
+
+ class SomeTypeCompiler(compiler.GenericTypeCompiler):
+ # transparently decorated w/ kw decorator
+ def visit_VARCHAR(self, type_):
+ return "MYVARCHAR"
+
+ # not affected
+ def visit_INTEGER(self, type_, **kw):
+ return "MYINTEGER %s" % kw['type_expression'].name
+
+ dialect = default.DefaultDialect()
+ dialect.type_compiler = SomeTypeCompiler(dialect)
+ self.assert_compile(
+ ddl.CreateColumn(Column('bar', VARCHAR(50))),
+ "bar MYVARCHAR",
+ dialect=dialect
+ )
+ self.assert_compile(
+ ddl.CreateColumn(Column('bar', INTEGER)),
+ "bar MYINTEGER bar",
+ dialect=dialect
+ )
+
+
+class TestKWArgPassThru(AssertsCompiledSQL, fixtures.TestBase):
+ __backend__ = True
+
+ def test_user_defined(self):
+ """test that dialects pass the column through on DDL."""
+
+ class MyType(types.UserDefinedType):
+ def get_col_spec(self, **kw):
+ return "FOOB %s" % kw['type_expression'].name
+
+ m = MetaData()
+ t = Table('t', m, Column('bar', MyType))
+ self.assert_compile(
+ ddl.CreateColumn(t.c.bar),
+ "bar FOOB bar"
+ )
+
class NumericRawSQLTest(fixtures.TestBase):