summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/compiler.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2015-01-16 20:03:33 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2015-01-16 20:03:33 -0500
commitf3a892a3ef666e299107a990bf4eae7ed9a953ae (patch)
tree01c0bbb71be7b397fd2f91b406c3ae7889b2306d /lib/sqlalchemy/sql/compiler.py
parent79fa69f1f37fdbc0dfec6bdea1e07f52bfe18f7b (diff)
downloadsqlalchemy-f3a892a3ef666e299107a990bf4eae7ed9a953ae.tar.gz
- Custom dialects that implement :class:`.GenericTypeCompiler` can
now be constructed such that the visit methods receive an indication of the owning expression object, if any. Any visit method that accepts keyword arguments (e.g. ``**kw``) will in most cases receive a keyword argument ``type_expression``, referring to the expression object that the type is contained within. For columns in DDL, the dialect's compiler class may need to alter its ``get_column_specification()`` method to support this as well. The ``UserDefinedType.get_col_spec()`` method will also receive ``type_expression`` if it provides ``**kw`` in its argument signature. fixes #3074
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r--lib/sqlalchemy/sql/compiler.py135
1 files changed, 69 insertions, 66 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index ca14c9371..da62b1434 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -248,15 +248,16 @@ class Compiled(object):
return self.execute(*multiparams, **params).scalar()
-class TypeCompiler(object):
-
+class TypeCompiler(util.with_metaclass(util.EnsureKWArgType, object)):
"""Produces DDL specification for TypeEngine objects."""
+ ensure_kwarg = 'visit_\w+'
+
def __init__(self, dialect):
self.dialect = dialect
- def process(self, type_):
- return type_._compiler_dispatch(self)
+ def process(self, type_, **kw):
+ return type_._compiler_dispatch(self, **kw)
class _CompileLabel(visitors.Visitable):
@@ -638,8 +639,9 @@ class SQLCompiler(Compiled):
def visit_index(self, index, **kwargs):
return index.name
- def visit_typeclause(self, typeclause, **kwargs):
- return self.dialect.type_compiler.process(typeclause.type)
+ def visit_typeclause(self, typeclause, **kw):
+ kw['type_expression'] = typeclause
+ return self.dialect.type_compiler.process(typeclause.type, **kw)
def post_process_text(self, text):
return text
@@ -2259,7 +2261,8 @@ class DDLCompiler(Compiled):
def get_column_specification(self, column, **kwargs):
colspec = self.preparer.format_column(column) + " " + \
- self.dialect.type_compiler.process(column.type)
+ self.dialect.type_compiler.process(
+ column.type, type_expression=column)
default = self.get_column_default_string(column)
if default is not None:
colspec += " DEFAULT " + default
@@ -2383,13 +2386,13 @@ class DDLCompiler(Compiled):
class GenericTypeCompiler(TypeCompiler):
- def visit_FLOAT(self, type_):
+ def visit_FLOAT(self, type_, **kw):
return "FLOAT"
- def visit_REAL(self, type_):
+ def visit_REAL(self, type_, **kw):
return "REAL"
- def visit_NUMERIC(self, type_):
+ def visit_NUMERIC(self, type_, **kw):
if type_.precision is None:
return "NUMERIC"
elif type_.scale is None:
@@ -2400,7 +2403,7 @@ class GenericTypeCompiler(TypeCompiler):
{'precision': type_.precision,
'scale': type_.scale}
- def visit_DECIMAL(self, type_):
+ def visit_DECIMAL(self, type_, **kw):
if type_.precision is None:
return "DECIMAL"
elif type_.scale is None:
@@ -2411,31 +2414,31 @@ class GenericTypeCompiler(TypeCompiler):
{'precision': type_.precision,
'scale': type_.scale}
- def visit_INTEGER(self, type_):
+ def visit_INTEGER(self, type_, **kw):
return "INTEGER"
- def visit_SMALLINT(self, type_):
+ def visit_SMALLINT(self, type_, **kw):
return "SMALLINT"
- def visit_BIGINT(self, type_):
+ def visit_BIGINT(self, type_, **kw):
return "BIGINT"
- def visit_TIMESTAMP(self, type_):
+ def visit_TIMESTAMP(self, type_, **kw):
return 'TIMESTAMP'
- def visit_DATETIME(self, type_):
+ def visit_DATETIME(self, type_, **kw):
return "DATETIME"
- def visit_DATE(self, type_):
+ def visit_DATE(self, type_, **kw):
return "DATE"
- def visit_TIME(self, type_):
+ def visit_TIME(self, type_, **kw):
return "TIME"
- def visit_CLOB(self, type_):
+ def visit_CLOB(self, type_, **kw):
return "CLOB"
- def visit_NCLOB(self, type_):
+ def visit_NCLOB(self, type_, **kw):
return "NCLOB"
def _render_string_type(self, type_, name):
@@ -2447,91 +2450,91 @@ class GenericTypeCompiler(TypeCompiler):
text += ' COLLATE "%s"' % type_.collation
return text
- def visit_CHAR(self, type_):
+ def visit_CHAR(self, type_, **kw):
return self._render_string_type(type_, "CHAR")
- def visit_NCHAR(self, type_):
+ def visit_NCHAR(self, type_, **kw):
return self._render_string_type(type_, "NCHAR")
- def visit_VARCHAR(self, type_):
+ def visit_VARCHAR(self, type_, **kw):
return self._render_string_type(type_, "VARCHAR")
- def visit_NVARCHAR(self, type_):
+ def visit_NVARCHAR(self, type_, **kw):
return self._render_string_type(type_, "NVARCHAR")
- def visit_TEXT(self, type_):
+ def visit_TEXT(self, type_, **kw):
return self._render_string_type(type_, "TEXT")
- def visit_BLOB(self, type_):
+ def visit_BLOB(self, type_, **kw):
return "BLOB"
- def visit_BINARY(self, type_):
+ def visit_BINARY(self, type_, **kw):
return "BINARY" + (type_.length and "(%d)" % type_.length or "")
- def visit_VARBINARY(self, type_):
+ def visit_VARBINARY(self, type_, **kw):
return "VARBINARY" + (type_.length and "(%d)" % type_.length or "")
- def visit_BOOLEAN(self, type_):
+ def visit_BOOLEAN(self, type_, **kw):
return "BOOLEAN"
- def visit_large_binary(self, type_):
- return self.visit_BLOB(type_)
+ def visit_large_binary(self, type_, **kw):
+ return self.visit_BLOB(type_, **kw)
- def visit_boolean(self, type_):
- return self.visit_BOOLEAN(type_)
+ def visit_boolean(self, type_, **kw):
+ return self.visit_BOOLEAN(type_, **kw)
- def visit_time(self, type_):
- return self.visit_TIME(type_)
+ def visit_time(self, type_, **kw):
+ return self.visit_TIME(type_, **kw)
- def visit_datetime(self, type_):
- return self.visit_DATETIME(type_)
+ def visit_datetime(self, type_, **kw):
+ return self.visit_DATETIME(type_, **kw)
- def visit_date(self, type_):
- return self.visit_DATE(type_)
+ def visit_date(self, type_, **kw):
+ return self.visit_DATE(type_, **kw)
- def visit_big_integer(self, type_):
- return self.visit_BIGINT(type_)
+ def visit_big_integer(self, type_, **kw):
+ return self.visit_BIGINT(type_, **kw)
- def visit_small_integer(self, type_):
- return self.visit_SMALLINT(type_)
+ def visit_small_integer(self, type_, **kw):
+ return self.visit_SMALLINT(type_, **kw)
- def visit_integer(self, type_):
- return self.visit_INTEGER(type_)
+ def visit_integer(self, type_, **kw):
+ return self.visit_INTEGER(type_, **kw)
- def visit_real(self, type_):
- return self.visit_REAL(type_)
+ def visit_real(self, type_, **kw):
+ return self.visit_REAL(type_, **kw)
- def visit_float(self, type_):
- return self.visit_FLOAT(type_)
+ def visit_float(self, type_, **kw):
+ return self.visit_FLOAT(type_, **kw)
- def visit_numeric(self, type_):
- return self.visit_NUMERIC(type_)
+ def visit_numeric(self, type_, **kw):
+ return self.visit_NUMERIC(type_, **kw)
- def visit_string(self, type_):
- return self.visit_VARCHAR(type_)
+ def visit_string(self, type_, **kw):
+ return self.visit_VARCHAR(type_, **kw)
- def visit_unicode(self, type_):
- return self.visit_VARCHAR(type_)
+ def visit_unicode(self, type_, **kw):
+ return self.visit_VARCHAR(type_, **kw)
- def visit_text(self, type_):
- return self.visit_TEXT(type_)
+ def visit_text(self, type_, **kw):
+ return self.visit_TEXT(type_, **kw)
- def visit_unicode_text(self, type_):
- return self.visit_TEXT(type_)
+ def visit_unicode_text(self, type_, **kw):
+ return self.visit_TEXT(type_, **kw)
- def visit_enum(self, type_):
- return self.visit_VARCHAR(type_)
+ def visit_enum(self, type_, **kw):
+ return self.visit_VARCHAR(type_, **kw)
- def visit_null(self, type_):
+ def visit_null(self, type_, **kw):
raise exc.CompileError("Can't generate DDL for %r; "
"did you forget to specify a "
"type on this Column?" % type_)
- def visit_type_decorator(self, type_):
- return self.process(type_.type_engine(self.dialect))
+ def visit_type_decorator(self, type_, **kw):
+ return self.process(type_.type_engine(self.dialect), **kw)
- def visit_user_defined(self, type_):
- return type_.get_col_spec()
+ def visit_user_defined(self, type_, **kw):
+ return type_.get_col_spec(**kw)
class IdentifierPreparer(object):