summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/dialects/mysql/base.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2015-01-16 20:03:33 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2015-01-16 20:03:33 -0500
commitf3a892a3ef666e299107a990bf4eae7ed9a953ae (patch)
tree01c0bbb71be7b397fd2f91b406c3ae7889b2306d /lib/sqlalchemy/dialects/mysql/base.py
parent79fa69f1f37fdbc0dfec6bdea1e07f52bfe18f7b (diff)
downloadsqlalchemy-f3a892a3ef666e299107a990bf4eae7ed9a953ae.tar.gz
- Custom dialects that implement :class:`.GenericTypeCompiler` can
now be constructed such that the visit methods receive an indication of the owning expression object, if any. Any visit method that accepts keyword arguments (e.g. ``**kw``) will in most cases receive a keyword argument ``type_expression``, referring to the expression object that the type is contained within. For columns in DDL, the dialect's compiler class may need to alter its ``get_column_specification()`` method to support this as well. The ``UserDefinedType.get_col_spec()`` method will also receive ``type_expression`` if it provides ``**kw`` in its argument signature. fixes #3074
Diffstat (limited to 'lib/sqlalchemy/dialects/mysql/base.py')
-rw-r--r--lib/sqlalchemy/dialects/mysql/base.py76
1 files changed, 39 insertions, 37 deletions
diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py
index 9c3f23cb2..ca56a4d23 100644
--- a/lib/sqlalchemy/dialects/mysql/base.py
+++ b/lib/sqlalchemy/dialects/mysql/base.py
@@ -1859,9 +1859,11 @@ class MySQLDDLCompiler(compiler.DDLCompiler):
def get_column_specification(self, column, **kw):
"""Builds column DDL."""
- colspec = [self.preparer.format_column(column),
- self.dialect.type_compiler.process(column.type)
- ]
+ colspec = [
+ self.preparer.format_column(column),
+ self.dialect.type_compiler.process(
+ column.type, type_expression=column)
+ ]
default = self.get_column_default_string(column)
if default is not None:
@@ -2059,7 +2061,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
def _mysql_type(self, type_):
return isinstance(type_, (_StringType, _NumericType))
- def visit_NUMERIC(self, type_):
+ def visit_NUMERIC(self, type_, **kw):
if type_.precision is None:
return self._extend_numeric(type_, "NUMERIC")
elif type_.scale is None:
@@ -2072,7 +2074,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
{'precision': type_.precision,
'scale': type_.scale})
- def visit_DECIMAL(self, type_):
+ def visit_DECIMAL(self, type_, **kw):
if type_.precision is None:
return self._extend_numeric(type_, "DECIMAL")
elif type_.scale is None:
@@ -2085,7 +2087,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
{'precision': type_.precision,
'scale': type_.scale})
- def visit_DOUBLE(self, type_):
+ def visit_DOUBLE(self, type_, **kw):
if type_.precision is not None and type_.scale is not None:
return self._extend_numeric(type_,
"DOUBLE(%(precision)s, %(scale)s)" %
@@ -2094,7 +2096,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
else:
return self._extend_numeric(type_, 'DOUBLE')
- def visit_REAL(self, type_):
+ def visit_REAL(self, type_, **kw):
if type_.precision is not None and type_.scale is not None:
return self._extend_numeric(type_,
"REAL(%(precision)s, %(scale)s)" %
@@ -2103,7 +2105,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
else:
return self._extend_numeric(type_, 'REAL')
- def visit_FLOAT(self, type_):
+ def visit_FLOAT(self, type_, **kw):
if self._mysql_type(type_) and \
type_.scale is not None and \
type_.precision is not None:
@@ -2115,7 +2117,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
else:
return self._extend_numeric(type_, "FLOAT")
- def visit_INTEGER(self, type_):
+ def visit_INTEGER(self, type_, **kw):
if self._mysql_type(type_) and type_.display_width is not None:
return self._extend_numeric(
type_, "INTEGER(%(display_width)s)" %
@@ -2123,7 +2125,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
else:
return self._extend_numeric(type_, "INTEGER")
- def visit_BIGINT(self, type_):
+ def visit_BIGINT(self, type_, **kw):
if self._mysql_type(type_) and type_.display_width is not None:
return self._extend_numeric(
type_, "BIGINT(%(display_width)s)" %
@@ -2131,7 +2133,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
else:
return self._extend_numeric(type_, "BIGINT")
- def visit_MEDIUMINT(self, type_):
+ def visit_MEDIUMINT(self, type_, **kw):
if self._mysql_type(type_) and type_.display_width is not None:
return self._extend_numeric(
type_, "MEDIUMINT(%(display_width)s)" %
@@ -2139,14 +2141,14 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
else:
return self._extend_numeric(type_, "MEDIUMINT")
- def visit_TINYINT(self, type_):
+ def visit_TINYINT(self, type_, **kw):
if self._mysql_type(type_) and type_.display_width is not None:
return self._extend_numeric(type_,
"TINYINT(%s)" % type_.display_width)
else:
return self._extend_numeric(type_, "TINYINT")
- def visit_SMALLINT(self, type_):
+ def visit_SMALLINT(self, type_, **kw):
if self._mysql_type(type_) and type_.display_width is not None:
return self._extend_numeric(type_,
"SMALLINT(%(display_width)s)" %
@@ -2155,55 +2157,55 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
else:
return self._extend_numeric(type_, "SMALLINT")
- def visit_BIT(self, type_):
+ def visit_BIT(self, type_, **kw):
if type_.length is not None:
return "BIT(%s)" % type_.length
else:
return "BIT"
- def visit_DATETIME(self, type_):
+ def visit_DATETIME(self, type_, **kw):
if getattr(type_, 'fsp', None):
return "DATETIME(%d)" % type_.fsp
else:
return "DATETIME"
- def visit_DATE(self, type_):
+ def visit_DATE(self, type_, **kw):
return "DATE"
- def visit_TIME(self, type_):
+ def visit_TIME(self, type_, **kw):
if getattr(type_, 'fsp', None):
return "TIME(%d)" % type_.fsp
else:
return "TIME"
- def visit_TIMESTAMP(self, type_):
+ def visit_TIMESTAMP(self, type_, **kw):
if getattr(type_, 'fsp', None):
return "TIMESTAMP(%d)" % type_.fsp
else:
return "TIMESTAMP"
- def visit_YEAR(self, type_):
+ def visit_YEAR(self, type_, **kw):
if type_.display_width is None:
return "YEAR"
else:
return "YEAR(%s)" % type_.display_width
- def visit_TEXT(self, type_):
+ def visit_TEXT(self, type_, **kw):
if type_.length:
return self._extend_string(type_, {}, "TEXT(%d)" % type_.length)
else:
return self._extend_string(type_, {}, "TEXT")
- def visit_TINYTEXT(self, type_):
+ def visit_TINYTEXT(self, type_, **kw):
return self._extend_string(type_, {}, "TINYTEXT")
- def visit_MEDIUMTEXT(self, type_):
+ def visit_MEDIUMTEXT(self, type_, **kw):
return self._extend_string(type_, {}, "MEDIUMTEXT")
- def visit_LONGTEXT(self, type_):
+ def visit_LONGTEXT(self, type_, **kw):
return self._extend_string(type_, {}, "LONGTEXT")
- def visit_VARCHAR(self, type_):
+ def visit_VARCHAR(self, type_, **kw):
if type_.length:
return self._extend_string(
type_, {}, "VARCHAR(%d)" % type_.length)
@@ -2212,14 +2214,14 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
"VARCHAR requires a length on dialect %s" %
self.dialect.name)
- def visit_CHAR(self, type_):
+ def visit_CHAR(self, type_, **kw):
if type_.length:
return self._extend_string(type_, {}, "CHAR(%(length)s)" %
{'length': type_.length})
else:
return self._extend_string(type_, {}, "CHAR")
- def visit_NVARCHAR(self, type_):
+ def visit_NVARCHAR(self, type_, **kw):
# We'll actually generate the equiv. "NATIONAL VARCHAR" instead
# of "NVARCHAR".
if type_.length:
@@ -2231,7 +2233,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
"NVARCHAR requires a length on dialect %s" %
self.dialect.name)
- def visit_NCHAR(self, type_):
+ def visit_NCHAR(self, type_, **kw):
# We'll actually generate the equiv.
# "NATIONAL CHAR" instead of "NCHAR".
if type_.length:
@@ -2241,31 +2243,31 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
else:
return self._extend_string(type_, {'national': True}, "CHAR")
- def visit_VARBINARY(self, type_):
+ def visit_VARBINARY(self, type_, **kw):
return "VARBINARY(%d)" % type_.length
- def visit_large_binary(self, type_):
+ def visit_large_binary(self, type_, **kw):
return self.visit_BLOB(type_)
- def visit_enum(self, type_):
+ def visit_enum(self, type_, **kw):
if not type_.native_enum:
return super(MySQLTypeCompiler, self).visit_enum(type_)
else:
return self._visit_enumerated_values("ENUM", type_, type_.enums)
- def visit_BLOB(self, type_):
+ def visit_BLOB(self, type_, **kw):
if type_.length:
return "BLOB(%d)" % type_.length
else:
return "BLOB"
- def visit_TINYBLOB(self, type_):
+ def visit_TINYBLOB(self, type_, **kw):
return "TINYBLOB"
- def visit_MEDIUMBLOB(self, type_):
+ def visit_MEDIUMBLOB(self, type_, **kw):
return "MEDIUMBLOB"
- def visit_LONGBLOB(self, type_):
+ def visit_LONGBLOB(self, type_, **kw):
return "LONGBLOB"
def _visit_enumerated_values(self, name, type_, enumerated_values):
@@ -2276,15 +2278,15 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
name, ",".join(quoted_enums))
)
- def visit_ENUM(self, type_):
+ def visit_ENUM(self, type_, **kw):
return self._visit_enumerated_values("ENUM", type_,
type_._enumerated_values)
- def visit_SET(self, type_):
+ def visit_SET(self, type_, **kw):
return self._visit_enumerated_values("SET", type_,
type_._enumerated_values)
- def visit_BOOLEAN(self, type):
+ def visit_BOOLEAN(self, type, **kw):
return "BOOL"