diff options
Diffstat (limited to 'lib/sqlalchemy/sql')
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 135 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/type_api.py | 24 |
2 files changed, 90 insertions, 69 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index ca14c9371..da62b1434 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -248,15 +248,16 @@ class Compiled(object): return self.execute(*multiparams, **params).scalar() -class TypeCompiler(object): - +class TypeCompiler(util.with_metaclass(util.EnsureKWArgType, object)): """Produces DDL specification for TypeEngine objects.""" + ensure_kwarg = 'visit_\w+' + def __init__(self, dialect): self.dialect = dialect - def process(self, type_): - return type_._compiler_dispatch(self) + def process(self, type_, **kw): + return type_._compiler_dispatch(self, **kw) class _CompileLabel(visitors.Visitable): @@ -638,8 +639,9 @@ class SQLCompiler(Compiled): def visit_index(self, index, **kwargs): return index.name - def visit_typeclause(self, typeclause, **kwargs): - return self.dialect.type_compiler.process(typeclause.type) + def visit_typeclause(self, typeclause, **kw): + kw['type_expression'] = typeclause + return self.dialect.type_compiler.process(typeclause.type, **kw) def post_process_text(self, text): return text @@ -2259,7 +2261,8 @@ class DDLCompiler(Compiled): def get_column_specification(self, column, **kwargs): colspec = self.preparer.format_column(column) + " " + \ - self.dialect.type_compiler.process(column.type) + self.dialect.type_compiler.process( + column.type, type_expression=column) default = self.get_column_default_string(column) if default is not None: colspec += " DEFAULT " + default @@ -2383,13 +2386,13 @@ class DDLCompiler(Compiled): class GenericTypeCompiler(TypeCompiler): - def visit_FLOAT(self, type_): + def visit_FLOAT(self, type_, **kw): return "FLOAT" - def visit_REAL(self, type_): + def visit_REAL(self, type_, **kw): return "REAL" - def visit_NUMERIC(self, type_): + def visit_NUMERIC(self, type_, **kw): if type_.precision is None: return "NUMERIC" elif type_.scale is None: @@ -2400,7 +2403,7 @@ class GenericTypeCompiler(TypeCompiler): {'precision': type_.precision, 'scale': type_.scale} - def visit_DECIMAL(self, type_): + def visit_DECIMAL(self, type_, **kw): if type_.precision is None: return "DECIMAL" elif type_.scale is None: @@ -2411,31 +2414,31 @@ class GenericTypeCompiler(TypeCompiler): {'precision': type_.precision, 'scale': type_.scale} - def visit_INTEGER(self, type_): + def visit_INTEGER(self, type_, **kw): return "INTEGER" - def visit_SMALLINT(self, type_): + def visit_SMALLINT(self, type_, **kw): return "SMALLINT" - def visit_BIGINT(self, type_): + def visit_BIGINT(self, type_, **kw): return "BIGINT" - def visit_TIMESTAMP(self, type_): + def visit_TIMESTAMP(self, type_, **kw): return 'TIMESTAMP' - def visit_DATETIME(self, type_): + def visit_DATETIME(self, type_, **kw): return "DATETIME" - def visit_DATE(self, type_): + def visit_DATE(self, type_, **kw): return "DATE" - def visit_TIME(self, type_): + def visit_TIME(self, type_, **kw): return "TIME" - def visit_CLOB(self, type_): + def visit_CLOB(self, type_, **kw): return "CLOB" - def visit_NCLOB(self, type_): + def visit_NCLOB(self, type_, **kw): return "NCLOB" def _render_string_type(self, type_, name): @@ -2447,91 +2450,91 @@ class GenericTypeCompiler(TypeCompiler): text += ' COLLATE "%s"' % type_.collation return text - def visit_CHAR(self, type_): + def visit_CHAR(self, type_, **kw): return self._render_string_type(type_, "CHAR") - def visit_NCHAR(self, type_): + def visit_NCHAR(self, type_, **kw): return self._render_string_type(type_, "NCHAR") - def visit_VARCHAR(self, type_): + def visit_VARCHAR(self, type_, **kw): return self._render_string_type(type_, "VARCHAR") - def visit_NVARCHAR(self, type_): + def visit_NVARCHAR(self, type_, **kw): return self._render_string_type(type_, "NVARCHAR") - def visit_TEXT(self, type_): + def visit_TEXT(self, type_, **kw): return self._render_string_type(type_, "TEXT") - def visit_BLOB(self, type_): + def visit_BLOB(self, type_, **kw): return "BLOB" - def visit_BINARY(self, type_): + def visit_BINARY(self, type_, **kw): return "BINARY" + (type_.length and "(%d)" % type_.length or "") - def visit_VARBINARY(self, type_): + def visit_VARBINARY(self, type_, **kw): return "VARBINARY" + (type_.length and "(%d)" % type_.length or "") - def visit_BOOLEAN(self, type_): + def visit_BOOLEAN(self, type_, **kw): return "BOOLEAN" - def visit_large_binary(self, type_): - return self.visit_BLOB(type_) + def visit_large_binary(self, type_, **kw): + return self.visit_BLOB(type_, **kw) - def visit_boolean(self, type_): - return self.visit_BOOLEAN(type_) + def visit_boolean(self, type_, **kw): + return self.visit_BOOLEAN(type_, **kw) - def visit_time(self, type_): - return self.visit_TIME(type_) + def visit_time(self, type_, **kw): + return self.visit_TIME(type_, **kw) - def visit_datetime(self, type_): - return self.visit_DATETIME(type_) + def visit_datetime(self, type_, **kw): + return self.visit_DATETIME(type_, **kw) - def visit_date(self, type_): - return self.visit_DATE(type_) + def visit_date(self, type_, **kw): + return self.visit_DATE(type_, **kw) - def visit_big_integer(self, type_): - return self.visit_BIGINT(type_) + def visit_big_integer(self, type_, **kw): + return self.visit_BIGINT(type_, **kw) - def visit_small_integer(self, type_): - return self.visit_SMALLINT(type_) + def visit_small_integer(self, type_, **kw): + return self.visit_SMALLINT(type_, **kw) - def visit_integer(self, type_): - return self.visit_INTEGER(type_) + def visit_integer(self, type_, **kw): + return self.visit_INTEGER(type_, **kw) - def visit_real(self, type_): - return self.visit_REAL(type_) + def visit_real(self, type_, **kw): + return self.visit_REAL(type_, **kw) - def visit_float(self, type_): - return self.visit_FLOAT(type_) + def visit_float(self, type_, **kw): + return self.visit_FLOAT(type_, **kw) - def visit_numeric(self, type_): - return self.visit_NUMERIC(type_) + def visit_numeric(self, type_, **kw): + return self.visit_NUMERIC(type_, **kw) - def visit_string(self, type_): - return self.visit_VARCHAR(type_) + def visit_string(self, type_, **kw): + return self.visit_VARCHAR(type_, **kw) - def visit_unicode(self, type_): - return self.visit_VARCHAR(type_) + def visit_unicode(self, type_, **kw): + return self.visit_VARCHAR(type_, **kw) - def visit_text(self, type_): - return self.visit_TEXT(type_) + def visit_text(self, type_, **kw): + return self.visit_TEXT(type_, **kw) - def visit_unicode_text(self, type_): - return self.visit_TEXT(type_) + def visit_unicode_text(self, type_, **kw): + return self.visit_TEXT(type_, **kw) - def visit_enum(self, type_): - return self.visit_VARCHAR(type_) + def visit_enum(self, type_, **kw): + return self.visit_VARCHAR(type_, **kw) - def visit_null(self, type_): + def visit_null(self, type_, **kw): raise exc.CompileError("Can't generate DDL for %r; " "did you forget to specify a " "type on this Column?" % type_) - def visit_type_decorator(self, type_): - return self.process(type_.type_engine(self.dialect)) + def visit_type_decorator(self, type_, **kw): + return self.process(type_.type_engine(self.dialect), **kw) - def visit_user_defined(self, type_): - return type_.get_col_spec() + def visit_user_defined(self, type_, **kw): + return type_.get_col_spec(**kw) class IdentifierPreparer(object): diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py index bff497800..19398ae96 100644 --- a/lib/sqlalchemy/sql/type_api.py +++ b/lib/sqlalchemy/sql/type_api.py @@ -12,7 +12,7 @@ from .. import exc, util from . import operators -from .visitors import Visitable +from .visitors import Visitable, VisitableType # these are back-assigned by sqltypes. BOOLEANTYPE = None @@ -460,7 +460,11 @@ class TypeEngine(Visitable): return util.generic_repr(self) -class UserDefinedType(TypeEngine): +class VisitableCheckKWArg(util.EnsureKWArgType, VisitableType): + pass + + +class UserDefinedType(util.with_metaclass(VisitableCheckKWArg, TypeEngine)): """Base for user defined types. This should be the base of new types. Note that @@ -473,7 +477,7 @@ class UserDefinedType(TypeEngine): def __init__(self, precision = 8): self.precision = precision - def get_col_spec(self): + def get_col_spec(self, **kw): return "MYTYPE(%s)" % self.precision def bind_processor(self, dialect): @@ -493,9 +497,23 @@ class UserDefinedType(TypeEngine): Column('data', MyType(16)) ) + The ``get_col_spec()`` method will in most cases receive a keyword + argument ``type_expression`` which refers to the owning expression + of the type as being compiled, such as a :class:`.Column` or + :func:`.cast` construct. This keyword is only sent if the method + accepts keyword arguments (e.g. ``**kw``) in its argument signature; + introspection is used to check for this in order to support legacy + forms of this function. + + .. versionadded:: 1.0.0 the owning expression is passed to + the ``get_col_spec()`` method via the keyword argument + ``type_expression``, if it receives ``**kw`` in its signature. + """ __visit_name__ = "user_defined" + ensure_kwarg = 'get_col_spec' + class Comparator(TypeEngine.Comparator): __slots__ = () |