summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/compiler.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r--lib/sqlalchemy/sql/compiler.py872
1 files changed, 538 insertions, 334 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 6af65ec14..6bfad4a76 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -6,19 +6,23 @@
"""Base SQL and DDL compiler implementations.
-Provides the :class:`~sqlalchemy.sql.compiler.DefaultCompiler` class, which is
-responsible for generating all SQL query strings, as well as
-:class:`~sqlalchemy.sql.compiler.SchemaGenerator` and :class:`~sqlalchemy.sql.compiler.SchemaDropper`
-which issue CREATE and DROP DDL for tables, sequences, and indexes.
-
-The elements in this module are used by public-facing constructs like
-:class:`~sqlalchemy.sql.expression.ClauseElement` and :class:`~sqlalchemy.engine.Engine`.
-While dialect authors will want to be familiar with this module for the purpose of
-creating database-specific compilers and schema generators, the module
-is otherwise internal to SQLAlchemy.
+Classes provided include:
+
+:class:`~sqlalchemy.sql.compiler.SQLCompiler` - renders SQL
+strings
+
+:class:`~sqlalchemy.sql.compiler.DDLCompiler` - renders DDL
+(data definition language) strings
+
+:class:`~sqlalchemy.sql.compiler.GenericTypeCompiler` - renders
+type specification strings.
+
+To generate user-defined SQL strings, see
+:module:`~sqlalchemy.ext.compiler`.
+
"""
-import string, re
+import re
from sqlalchemy import schema, engine, util, exc
from sqlalchemy.sql import operators, functions, util as sql_util, visitors
from sqlalchemy.sql import expression as sql
@@ -58,40 +62,43 @@ BIND_TEMPLATES = {
OPERATORS = {
- operators.and_ : 'AND',
- operators.or_ : 'OR',
- operators.inv : 'NOT',
- operators.add : '+',
- operators.mul : '*',
- operators.sub : '-',
- operators.div : '/',
- operators.mod : '%',
- operators.truediv : '/',
- operators.lt : '<',
- operators.le : '<=',
- operators.ne : '!=',
- operators.gt : '>',
- operators.ge : '>=',
- operators.eq : '=',
- operators.distinct_op : 'DISTINCT',
- operators.concat_op : '||',
- operators.like_op : lambda x, y, escape=None: '%s LIKE %s' % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''),
- operators.notlike_op : lambda x, y, escape=None: '%s NOT LIKE %s' % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''),
- operators.ilike_op : lambda x, y, escape=None: "lower(%s) LIKE lower(%s)" % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''),
- operators.notilike_op : lambda x, y, escape=None: "lower(%s) NOT LIKE lower(%s)" % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''),
- operators.between_op : 'BETWEEN',
- operators.match_op : 'MATCH',
- operators.in_op : 'IN',
- operators.notin_op : 'NOT IN',
+ # binary
+ operators.and_ : ' AND ',
+ operators.or_ : ' OR ',
+ operators.add : ' + ',
+ operators.mul : ' * ',
+ operators.sub : ' - ',
+ # Py2K
+ operators.div : ' / ',
+ # end Py2K
+ operators.mod : ' % ',
+ operators.truediv : ' / ',
+ operators.lt : ' < ',
+ operators.le : ' <= ',
+ operators.ne : ' != ',
+ operators.gt : ' > ',
+ operators.ge : ' >= ',
+ operators.eq : ' = ',
+ operators.concat_op : ' || ',
+ operators.between_op : ' BETWEEN ',
+ operators.match_op : ' MATCH ',
+ operators.in_op : ' IN ',
+ operators.notin_op : ' NOT IN ',
operators.comma_op : ', ',
- operators.desc_op : 'DESC',
- operators.asc_op : 'ASC',
- operators.from_ : 'FROM',
- operators.as_ : 'AS',
- operators.exists : 'EXISTS',
- operators.is_ : 'IS',
- operators.isnot : 'IS NOT',
- operators.collate : 'COLLATE',
+ operators.from_ : ' FROM ',
+ operators.as_ : ' AS ',
+ operators.is_ : ' IS ',
+ operators.isnot : ' IS NOT ',
+ operators.collate : ' COLLATE ',
+
+ # unary
+ operators.exists : 'EXISTS ',
+ operators.distinct_op : 'DISTINCT ',
+ operators.inv : 'NOT ',
+
+ # modifiers
+ operators.desc_op : ' DESC',
+ operators.asc_op : ' ASC',
}
FUNCTIONS = {
@@ -140,7 +147,7 @@ class _CompileLabel(visitors.Visitable):
def quote(self):
return self.element.quote
-class DefaultCompiler(engine.Compiled):
+class SQLCompiler(engine.Compiled):
"""Default implementation of Compiled.
Compiles ClauseElements into SQL strings. Uses a similar visit
@@ -148,14 +155,14 @@ class DefaultCompiler(engine.Compiled):
"""
- operators = OPERATORS
- functions = FUNCTIONS
extract_map = EXTRACT_MAP
- # if we are insert/update/delete.
- # set to true when we visit an INSERT, UPDATE or DELETE
+ # class-level defaults which can be set at the instance
+ # level to define if this Compiled instance represents
+ # INSERT/UPDATE/DELETE
isdelete = isinsert = isupdate = False
-
+ returning = None
+
def __init__(self, dialect, statement, column_keys=None, inline=False, **kwargs):
"""Construct a new ``DefaultCompiler`` object.
@@ -170,7 +177,9 @@ class DefaultCompiler(engine.Compiled):
statement.
"""
- engine.Compiled.__init__(self, dialect, statement, column_keys, **kwargs)
+ engine.Compiled.__init__(self, dialect, statement, **kwargs)
+
+ self.column_keys = column_keys
# compile INSERT/UPDATE defaults/sequences inlined (no pre-execute)
self.inline = inline or getattr(statement, 'inline', False)
@@ -210,12 +219,6 @@ class DefaultCompiler(engine.Compiled):
# or dialect.max_identifier_length
self.truncated_names = {}
- def compile(self):
- self.string = self.process(self.statement)
-
- def process(self, obj, **kwargs):
- return obj._compiler_dispatch(self, **kwargs)
-
def is_subquery(self):
return len(self.stack) > 1
@@ -223,7 +226,6 @@ class DefaultCompiler(engine.Compiled):
"""return a dictionary of bind parameter keys and values"""
if params:
- params = util.column_dict(params)
pd = {}
for bindparam, name in self.bind_names.iteritems():
for paramname in (bindparam.key, bindparam.shortname, name):
@@ -245,7 +247,10 @@ class DefaultCompiler(engine.Compiled):
pd[self.bind_names[bindparam]] = bindparam.value
return pd
- params = property(construct_params)
+ params = property(construct_params, doc="""
+ Return the bind params for this compiled object.
+
+ """)
def default_from(self):
"""Called when a SELECT statement has no froms, and no FROM clause is to be appended.
@@ -267,10 +272,11 @@ class DefaultCompiler(engine.Compiled):
self._truncated_identifier("colident", label.name) or label.name
if result_map is not None:
- result_map[labelname.lower()] = (label.name, (label, label.element, labelname), label.element.type)
+ result_map[labelname.lower()] = \
+ (label.name, (label, label.element, labelname), label.element.type)
- return self.process(label.element) + " " + \
- self.operator_string(operators.as_) + " " + \
+ return self.process(label.element) + \
+ OPERATORS[operators.as_] + \
self.preparer.format_label(label, labelname)
else:
return self.process(label.element)
@@ -292,14 +298,17 @@ class DefaultCompiler(engine.Compiled):
return name
else:
if column.table.schema:
- schema_prefix = self.preparer.quote_schema(column.table.schema, column.table.quote_schema) + '.'
+ schema_prefix = self.preparer.quote_schema(
+ column.table.schema,
+ column.table.quote_schema) + '.'
else:
schema_prefix = ''
tablename = column.table.name
tablename = isinstance(tablename, sql._generated_label) and \
self._truncated_identifier("alias", tablename) or tablename
- return schema_prefix + self.preparer.quote(tablename, column.table.quote) + "." + name
+ return schema_prefix + \
+ self.preparer.quote(tablename, column.table.quote) + "." + name
def escape_literal_column(self, text):
"""provide escaping for the literal_column() construct."""
@@ -314,7 +323,7 @@ class DefaultCompiler(engine.Compiled):
return index.name
def visit_typeclause(self, typeclause, **kwargs):
- return typeclause.type.dialect_impl(self.dialect).get_col_spec()
+ return self.dialect.type_compiler.process(typeclause.type)
def post_process_text(self, text):
return text
@@ -343,10 +352,8 @@ class DefaultCompiler(engine.Compiled):
sep = clauselist.operator
if sep is None:
sep = " "
- elif sep is operators.comma_op:
- sep = ', '
else:
- sep = " " + self.operator_string(clauselist.operator) + " "
+ sep = OPERATORS[clauselist.operator]
return sep.join(s for s in (self.process(c) for c in clauselist.clauses)
if s is not None)
@@ -362,7 +369,8 @@ class DefaultCompiler(engine.Compiled):
return x
def visit_cast(self, cast, **kwargs):
- return "CAST(%s AS %s)" % (self.process(cast.clause), self.process(cast.typeclause))
+ return "CAST(%s AS %s)" % \
+ (self.process(cast.clause), self.process(cast.typeclause))
def visit_extract(self, extract, **kwargs):
field = self.extract_map.get(extract.field, extract.field)
@@ -372,26 +380,26 @@ class DefaultCompiler(engine.Compiled):
if result_map is not None:
result_map[func.name.lower()] = (func.name, None, func.type)
- name = self.function_string(func)
-
- if util.callable(name):
- return name(*[self.process(x) for x in func.clauses])
+ disp = getattr(self, "visit_%s_func" % func.name.lower(), None)
+ if disp:
+ return disp(func, **kwargs)
else:
- return ".".join(func.packagenames + [name]) % {'expr':self.function_argspec(func)}
+ name = FUNCTIONS.get(func.__class__, func.name + "%(expr)s")
+ return ".".join(func.packagenames + [name]) % \
+ {'expr':self.function_argspec(func, **kwargs)}
def function_argspec(self, func, **kwargs):
return self.process(func.clause_expr, **kwargs)
- def function_string(self, func):
- return self.functions.get(func.__class__, self.functions.get(func.name, func.name + "%(expr)s"))
-
def visit_compound_select(self, cs, asfrom=False, parens=True, **kwargs):
entry = self.stack and self.stack[-1] or {}
self.stack.append({'from':entry.get('from', None), 'iswrapper':True})
- text = string.join((self.process(c, asfrom=asfrom, parens=False, compound_index=i)
- for i, c in enumerate(cs.selects)),
- " " + cs.keyword + " ")
+ text = (" " + cs.keyword + " ").join(
+ (self.process(c, asfrom=asfrom, parens=False, compound_index=i)
+ for i, c in enumerate(cs.selects))
+ )
+
group_by = self.process(cs._group_by_clause, asfrom=asfrom)
if group_by:
text += " GROUP BY " + group_by
@@ -408,27 +416,57 @@ class DefaultCompiler(engine.Compiled):
def visit_unary(self, unary, **kw):
s = self.process(unary.element, **kw)
if unary.operator:
- s = self.operator_string(unary.operator) + " " + s
+ s = OPERATORS[unary.operator] + s
if unary.modifier:
- s = s + " " + self.operator_string(unary.modifier)
+ s = s + OPERATORS[unary.modifier]
return s
def visit_binary(self, binary, **kwargs):
- op = self.operator_string(binary.operator)
- if util.callable(op):
- return op(self.process(binary.left), self.process(binary.right), **binary.modifiers)
- else:
- return self.process(binary.left) + " " + op + " " + self.process(binary.right)
+
+ return self._operator_dispatch(binary.operator,
+ binary,
+ lambda opstr: self.process(binary.left) + opstr + self.process(binary.right),
+ **kwargs
+ )
- def operator_string(self, operator):
- return self.operators.get(operator, str(operator))
+ def visit_like_op(self, binary, **kw):
+ escape = binary.modifiers.get("escape", None)
+ return '%s LIKE %s' % (self.process(binary.left), self.process(binary.right)) \
+ + (escape and ' ESCAPE \'%s\'' % escape or '')
+ def visit_notlike_op(self, binary, **kw):
+ escape = binary.modifiers.get("escape", None)
+ return '%s NOT LIKE %s' % (self.process(binary.left), self.process(binary.right)) \
+ + (escape and ' ESCAPE \'%s\'' % escape or '')
+
+ def visit_ilike_op(self, binary, **kw):
+ escape = binary.modifiers.get("escape", None)
+ return 'lower(%s) LIKE lower(%s)' % (self.process(binary.left), self.process(binary.right)) \
+ + (escape and ' ESCAPE \'%s\'' % escape or '')
+
+ def visit_notilike_op(self, binary, **kw):
+ escape = binary.modifiers.get("escape", None)
+ return 'lower(%s) NOT LIKE lower(%s)' % (self.process(binary.left), self.process(binary.right)) \
+ + (escape and ' ESCAPE \'%s\'' % escape or '')
+
+ def _operator_dispatch(self, operator, element, fn, **kw):
+ if util.callable(operator):
+ disp = getattr(self, "visit_%s" % operator.__name__, None)
+ if disp:
+ return disp(element, **kw)
+ else:
+ return fn(OPERATORS[operator])
+ else:
+ return fn(" " + operator + " ")
+
def visit_bindparam(self, bindparam, **kwargs):
name = self._truncate_bindparam(bindparam)
if name in self.binds:
existing = self.binds[name]
if existing is not bindparam and (existing.unique or bindparam.unique):
- raise exc.CompileError("Bind parameter '%s' conflicts with unique bind parameter of the same name" % bindparam.key)
+ raise exc.CompileError(
+ "Bind parameter '%s' conflicts with unique bind parameter of the same name" % bindparam.key
+ )
self.binds[bindparam.key] = self.binds[name] = bindparam
return self.bindparam_string(name)
@@ -491,7 +529,7 @@ class DefaultCompiler(engine.Compiled):
if isinstance(column, sql._Label):
return column
- if select.use_labels and column._label:
+ if select and select.use_labels and column._label:
return _CompileLabel(column, column._label)
if \
@@ -501,13 +539,15 @@ class DefaultCompiler(engine.Compiled):
column.table is not None and \
not isinstance(column.table, sql.Select):
return _CompileLabel(column, sql._generated_label(column.name))
- elif not isinstance(column, (sql._UnaryExpression, sql._TextClause, sql._BindParamClause)) \
+ elif not isinstance(column,
+ (sql._UnaryExpression, sql._TextClause, sql._BindParamClause)) \
and (not hasattr(column, 'name') or isinstance(column, sql.Function)):
return _CompileLabel(column, column.anon_label)
else:
return column
- def visit_select(self, select, asfrom=False, parens=True, iswrapper=False, compound_index=1, **kwargs):
+ def visit_select(self, select, asfrom=False, parens=True,
+ iswrapper=False, compound_index=1, **kwargs):
entry = self.stack and self.stack[-1] or {}
@@ -583,8 +623,10 @@ class DefaultCompiler(engine.Compiled):
return text
def get_select_precolumns(self, select):
- """Called when building a ``SELECT`` statement, position is just before column list."""
-
+ """Called when building a ``SELECT`` statement, position is just before
+ column list.
+
+ """
return select._distinct and "DISTINCT " or ""
def order_by_clause(self, select):
@@ -613,14 +655,16 @@ class DefaultCompiler(engine.Compiled):
def visit_table(self, table, asfrom=False, **kwargs):
if asfrom:
if getattr(table, "schema", None):
- return self.preparer.quote_schema(table.schema, table.quote_schema) + "." + self.preparer.quote(table.name, table.quote)
+ return self.preparer.quote_schema(table.schema, table.quote_schema) + \
+ "." + self.preparer.quote(table.name, table.quote)
else:
return self.preparer.quote(table.name, table.quote)
else:
return ""
def visit_join(self, join, asfrom=False, **kwargs):
- return (self.process(join.left, asfrom=True) + (join.isouter and " LEFT OUTER JOIN " or " JOIN ") + \
+ return (self.process(join.left, asfrom=True) + \
+ (join.isouter and " LEFT OUTER JOIN " or " JOIN ") + \
self.process(join.right, asfrom=True) + " ON " + self.process(join.onclause))
def visit_sequence(self, seq):
@@ -629,41 +673,75 @@ class DefaultCompiler(engine.Compiled):
def visit_insert(self, insert_stmt):
self.isinsert = True
colparams = self._get_colparams(insert_stmt)
- preparer = self.preparer
-
- insert = ' '.join(["INSERT"] +
- [self.process(x) for x in insert_stmt._prefixes])
- if not colparams and not self.dialect.supports_default_values and not self.dialect.supports_empty_insert:
+ if not colparams and \
+ not self.dialect.supports_default_values and \
+ not self.dialect.supports_empty_insert:
raise exc.CompileError(
"The version of %s you are using does not support empty inserts." % self.dialect.name)
- elif not colparams and self.dialect.supports_default_values:
- return (insert + " INTO %s DEFAULT VALUES" % (
- (preparer.format_table(insert_stmt.table),)))
- else:
- return (insert + " INTO %s (%s) VALUES (%s)" %
- (preparer.format_table(insert_stmt.table),
- ', '.join([preparer.format_column(c[0])
- for c in colparams]),
- ', '.join([c[1] for c in colparams])))
+ preparer = self.preparer
+ supports_default_values = self.dialect.supports_default_values
+
+ text = "INSERT"
+
+ prefixes = [self.process(x) for x in insert_stmt._prefixes]
+ if prefixes:
+ text += " " + " ".join(prefixes)
+
+ text += " INTO " + preparer.format_table(insert_stmt.table)
+
+ if colparams or not supports_default_values:
+ text += " (%s)" % ', '.join([preparer.format_column(c[0])
+ for c in colparams])
+
+ if self.returning or insert_stmt._returning:
+ self.returning = self.returning or insert_stmt._returning
+ returning_clause = self.returning_clause(insert_stmt, self.returning)
+
+ # cheating
+ if returning_clause.startswith("OUTPUT"):
+ text += " " + returning_clause
+ returning_clause = None
+
+ if not colparams and supports_default_values:
+ text += " DEFAULT VALUES"
+ else:
+ text += " VALUES (%s)" % \
+ ', '.join([c[1] for c in colparams])
+
+ if self.returning and returning_clause:
+ text += " " + returning_clause
+
+ return text
+
def visit_update(self, update_stmt):
self.stack.append({'from': set([update_stmt.table])})
self.isupdate = True
colparams = self._get_colparams(update_stmt)
- text = ' '.join((
- "UPDATE",
- self.preparer.format_table(update_stmt.table),
- 'SET',
- ', '.join(self.preparer.quote(c[0].name, c[0].quote) + '=' + c[1]
- for c in colparams)
- ))
-
+ text = "UPDATE " + self.preparer.format_table(update_stmt.table)
+
+ text += ' SET ' + \
+ ', '.join(
+ self.preparer.quote(c[0].name, c[0].quote) + '=' + c[1]
+ for c in colparams
+ )
+
+ if update_stmt._returning:
+ self.returning = update_stmt._returning
+ returning_clause = self.returning_clause(update_stmt, update_stmt._returning)
+ if returning_clause.startswith("OUTPUT"):
+ text += " " + returning_clause
+ returning_clause = None
+
if update_stmt._whereclause:
text += " WHERE " + self.process(update_stmt._whereclause)
+ if self.returning and returning_clause:
+ text += " " + returning_clause
+
self.stack.pop(-1)
return text
@@ -681,7 +759,8 @@ class DefaultCompiler(engine.Compiled):
self.postfetch = []
self.prefetch = []
-
+ self.returning = []
+
# no parameters in the statement, no parameters in the
# compiled params - return binds for all columns
if self.column_keys is None and stmt.parameters is None:
@@ -701,6 +780,15 @@ class DefaultCompiler(engine.Compiled):
# create a list of column assignment clauses as tuples
values = []
+
+ need_pks = self.isinsert and \
+ not self.inline and \
+ not self.statement._returning
+
+ implicit_returning = need_pks and \
+ self.dialect.implicit_returning and \
+ stmt.table.implicit_returning
+
for c in stmt.table.columns:
if c.key in parameters:
value = parameters[c.key]
@@ -710,19 +798,48 @@ class DefaultCompiler(engine.Compiled):
self.postfetch.append(c)
value = self.process(value.self_group())
values.append((c, value))
+
elif isinstance(c, schema.Column):
if self.isinsert:
- if (c.primary_key and self.dialect.preexecute_pk_sequences and not self.inline):
- if (((isinstance(c.default, schema.Sequence) and
- not c.default.optional) or
- not self.dialect.supports_pk_autoincrement) or
- (c.default is not None and
- not isinstance(c.default, schema.Sequence))):
- values.append((c, create_bind_param(c, None)))
- self.prefetch.append(c)
+ if c.primary_key and \
+ need_pks and \
+ (
+ c is not stmt.table._autoincrement_column or
+ not self.dialect.postfetch_lastrowid
+ ):
+
+ if implicit_returning:
+ if isinstance(c.default, schema.Sequence):
+ proc = self.process(c.default)
+ if proc is not None:
+ values.append((c, proc))
+ self.returning.append(c)
+ elif isinstance(c.default, schema.ColumnDefault) and \
+ isinstance(c.default.arg, sql.ClauseElement):
+ values.append((c, self.process(c.default.arg.self_group())))
+ self.returning.append(c)
+ elif c.default is not None:
+ values.append((c, create_bind_param(c, None)))
+ self.prefetch.append(c)
+ else:
+ self.returning.append(c)
+ else:
+ if (
+ c.default is not None and \
+ (
+ self.dialect.supports_sequences or
+ not isinstance(c.default, schema.Sequence)
+ )
+ ) or \
+ self.dialect.preexecute_autoincrement_sequences:
+
+ values.append((c, create_bind_param(c, None)))
+ self.prefetch.append(c)
+
elif isinstance(c.default, schema.ColumnDefault):
if isinstance(c.default.arg, sql.ClauseElement):
values.append((c, self.process(c.default.arg.self_group())))
+
if not c.primary_key:
# dont add primary key column to postfetch
self.postfetch.append(c)
@@ -759,9 +876,19 @@ class DefaultCompiler(engine.Compiled):
text = "DELETE FROM " + self.preparer.format_table(delete_stmt.table)
+ if delete_stmt._returning:
+ self.returning = delete_stmt._returning
+ returning_clause = self.returning_clause(delete_stmt, delete_stmt._returning)
+ if returning_clause.startswith("OUTPUT"):
+ text += " " + returning_clause
+ returning_clause = None
+
if delete_stmt._whereclause:
text += " WHERE " + self.process(delete_stmt._whereclause)
+ if self.returning and returning_clause:
+ text += " " + returning_clause
+
self.stack.pop(-1)
return text
@@ -775,110 +902,146 @@ class DefaultCompiler(engine.Compiled):
def visit_release_savepoint(self, savepoint_stmt):
return "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt)
- def __str__(self):
- return self.string or ''
-
-class DDLBase(engine.SchemaIterator):
- def find_alterables(self, tables):
- alterables = []
- class FindAlterables(schema.SchemaVisitor):
- def visit_foreign_key_constraint(self, constraint):
- if constraint.use_alter and constraint.table in tables:
- alterables.append(constraint)
- findalterables = FindAlterables()
- for table in tables:
- for c in table.constraints:
- findalterables.traverse(c)
- return alterables
-
- def _validate_identifier(self, ident, truncate):
- if truncate:
- if len(ident) > self.dialect.max_identifier_length:
- counter = getattr(self, 'counter', 0)
- self.counter = counter + 1
- return ident[0:self.dialect.max_identifier_length - 6] + "_" + hex(self.counter)[2:]
- else:
- return ident
- else:
- self.dialect.validate_identifier(ident)
- return ident
-
-
-class SchemaGenerator(DDLBase):
- def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs):
- super(SchemaGenerator, self).__init__(connection, **kwargs)
- self.checkfirst = checkfirst
- self.tables = tables and set(tables) or None
- self.preparer = dialect.identifier_preparer
- self.dialect = dialect
- def get_column_specification(self, column, first_pk=False):
- raise NotImplementedError()
+class DDLCompiler(engine.Compiled):
+ @property
+ def preparer(self):
+ return self.dialect.identifier_preparer
- def _can_create(self, table):
- self.dialect.validate_identifier(table.name)
- if table.schema:
- self.dialect.validate_identifier(table.schema)
- return not self.checkfirst or not self.dialect.has_table(self.connection, table.name, schema=table.schema)
+ def construct_params(self, params=None):
+ return None
+
+ def visit_ddl(self, ddl, **kwargs):
+ # table events can substitute table and schema name
+ context = ddl.context
+ if isinstance(ddl.schema_item, schema.Table):
+ context = context.copy()
+
+ preparer = self.dialect.identifier_preparer
+ path = preparer.format_table_seq(ddl.schema_item)
+ if len(path) == 1:
+ table, sch = path[0], ''
+ else:
+ table, sch = path[-1], path[0]
- def visit_metadata(self, metadata):
- if self.tables:
- tables = self.tables
- else:
- tables = metadata.tables.values()
- collection = [t for t in sql_util.sort_tables(tables) if self._can_create(t)]
- for table in collection:
- self.traverse_single(table)
- if self.dialect.supports_alter:
- for alterable in self.find_alterables(collection):
- self.add_foreignkey(alterable)
-
- def visit_table(self, table):
- for listener in table.ddl_listeners['before-create']:
- listener('before-create', table, self.connection)
+ context.setdefault('table', table)
+ context.setdefault('schema', sch)
+ context.setdefault('fullname', preparer.format_table(ddl.schema_item))
+
+ return ddl.statement % context
- for column in table.columns:
- if column.default is not None:
- self.traverse_single(column.default)
+ def visit_create_table(self, create):
+ table = create.element
+ preparer = self.dialect.identifier_preparer
- self.append("\n" + " ".join(['CREATE'] +
- table._prefixes +
+ text = "\n" + " ".join(['CREATE'] + \
+ table._prefixes + \
['TABLE',
- self.preparer.format_table(table),
- "("]))
+ preparer.format_table(table),
+ "("])
separator = "\n"
# if only one primary key, specify it along with the column
first_pk = False
for column in table.columns:
- self.append(separator)
+ text += separator
separator = ", \n"
- self.append("\t" + self.get_column_specification(column, first_pk=column.primary_key and not first_pk))
+ text += "\t" + self.get_column_specification(column, first_pk=column.primary_key and not first_pk)
if column.primary_key:
first_pk = True
- for constraint in column.constraints:
- self.traverse_single(constraint)
+ const = " ".join(self.process(constraint) for constraint in column.constraints)
+ if const:
+ text += " " + const
# On some DB order is significant: visit PK first, then the
# other constraints (engine.ReflectionTest.testbasic failed on FB2)
if table.primary_key:
- self.traverse_single(table.primary_key)
- for constraint in [c for c in table.constraints if c is not table.primary_key]:
- self.traverse_single(constraint)
+ text += ", \n\t" + self.process(table.primary_key)
+
+ const = ", \n\t".join(
+ self.process(constraint) for constraint in table.constraints
+ if constraint is not table.primary_key
+ and constraint.inline_ddl
+ and (not self.dialect.supports_alter or not getattr(constraint, 'use_alter', False))
+ )
+ if const:
+ text += ", \n\t" + const
+
+ text += "\n)%s\n\n" % self.post_create_table(table)
+ return text
+
+ def visit_drop_table(self, drop):
+ ret = "\nDROP TABLE " + self.preparer.format_table(drop.element)
+ if drop.cascade:
+ ret += " CASCADE CONSTRAINTS"
+ return ret
+
+ def visit_create_index(self, create):
+ index = create.element
+ preparer = self.preparer
+ text = "CREATE "
+ if index.unique:
+ text += "UNIQUE "
+ text += "INDEX %s ON %s (%s)" \
+ % (preparer.quote(self._validate_identifier(index.name, True), index.quote),
+ preparer.format_table(index.table),
+ ', '.join(preparer.quote(c.name, c.quote)
+ for c in index.columns))
+ return text
- self.append("\n)%s\n\n" % self.post_create_table(table))
- self.execute()
+ def visit_drop_index(self, drop):
+ index = drop.element
+ return "\nDROP INDEX " + \
+ self.preparer.quote(self._validate_identifier(index.name, False), index.quote)
- if hasattr(table, 'indexes'):
- for index in table.indexes:
- self.traverse_single(index)
+ def visit_add_constraint(self, create):
+ preparer = self.preparer
+ return "ALTER TABLE %s ADD %s" % (
+ self.preparer.format_table(create.element.table),
+ self.process(create.element)
+ )
+
+ def visit_drop_constraint(self, drop):
+ preparer = self.preparer
+ return "ALTER TABLE %s DROP CONSTRAINT %s%s" % (
+ self.preparer.format_table(drop.element.table),
+ self.preparer.format_constraint(drop.element),
+ " CASCADE" if drop.cascade else ""
+ )
+
+ def get_column_specification(self, column, **kwargs):
+ colspec = self.preparer.format_column(column) + " " + \
+ self.dialect.type_compiler.process(column.type)
+ default = self.get_column_default_string(column)
+ if default is not None:
+ colspec += " DEFAULT " + default
- for listener in table.ddl_listeners['after-create']:
- listener('after-create', table, self.connection)
+ if not column.nullable:
+ colspec += " NOT NULL"
+ return colspec
def post_create_table(self, table):
return ''
+ def _compile(self, tocompile, parameters):
+ """compile the given string/parameters using this SchemaGenerator's dialect."""
+
+ compiler = self.dialect.statement_compiler(self.dialect, tocompile, parameters)
+ compiler.compile()
+ return compiler
+
+ def _validate_identifier(self, ident, truncate):
+ if truncate:
+ if len(ident) > self.dialect.max_identifier_length:
+ counter = getattr(self, 'counter', 0)
+ self.counter = counter + 1
+ return ident[0:self.dialect.max_identifier_length - 6] + "_" + hex(self.counter)[2:]
+ else:
+ return ident
+ else:
+ self.dialect.validate_identifier(ident)
+ return ident
+
def get_column_default_string(self, column):
if isinstance(column.server_default, schema.DefaultClause):
if isinstance(column.server_default.arg, basestring):
@@ -888,149 +1051,190 @@ class SchemaGenerator(DDLBase):
else:
return None
- def _compile(self, tocompile, parameters):
- """compile the given string/parameters using this SchemaGenerator's dialect."""
- compiler = self.dialect.statement_compiler(self.dialect, tocompile, parameters)
- compiler.compile()
- return compiler
-
def visit_check_constraint(self, constraint):
- self.append(", \n\t")
+ text = ""
if constraint.name is not None:
- self.append("CONSTRAINT %s " %
- self.preparer.format_constraint(constraint))
- self.append(" CHECK (%s)" % constraint.sqltext)
- self.define_constraint_deferrability(constraint)
+ text += "CONSTRAINT %s " % \
+ self.preparer.format_constraint(constraint)
+ text += " CHECK (%s)" % constraint.sqltext
+ text += self.define_constraint_deferrability(constraint)
+ return text
def visit_column_check_constraint(self, constraint):
- self.append(" CHECK (%s)" % constraint.sqltext)
- self.define_constraint_deferrability(constraint)
+ text = " CHECK (%s)" % constraint.sqltext
+ text += self.define_constraint_deferrability(constraint)
+ return text
def visit_primary_key_constraint(self, constraint):
if len(constraint) == 0:
- return
- self.append(", \n\t")
+ return ''
+ text = ""
if constraint.name is not None:
- self.append("CONSTRAINT %s " % self.preparer.format_constraint(constraint))
- self.append("PRIMARY KEY ")
- self.append("(%s)" % ', '.join(self.preparer.quote(c.name, c.quote)
- for c in constraint))
- self.define_constraint_deferrability(constraint)
+ text += "CONSTRAINT %s " % self.preparer.format_constraint(constraint)
+ text += "PRIMARY KEY "
+ text += "(%s)" % ', '.join(self.preparer.quote(c.name, c.quote)
+ for c in constraint)
+ text += self.define_constraint_deferrability(constraint)
+ return text
def visit_foreign_key_constraint(self, constraint):
- if constraint.use_alter and self.dialect.supports_alter:
- return
- self.append(", \n\t ")
- self.define_foreign_key(constraint)
-
- def add_foreignkey(self, constraint):
- self.append("ALTER TABLE %s ADD " % self.preparer.format_table(constraint.table))
- self.define_foreign_key(constraint)
- self.execute()
-
- def define_foreign_key(self, constraint):
- preparer = self.preparer
+ preparer = self.dialect.identifier_preparer
+ text = ""
if constraint.name is not None:
- self.append("CONSTRAINT %s " %
- preparer.format_constraint(constraint))
- table = list(constraint.elements)[0].column.table
- self.append("FOREIGN KEY(%s) REFERENCES %s (%s)" % (
+ text += "CONSTRAINT %s " % \
+ preparer.format_constraint(constraint)
+ remote_table = list(constraint._elements.values())[0].column.table
+ text += "FOREIGN KEY(%s) REFERENCES %s (%s)" % (
', '.join(preparer.quote(f.parent.name, f.parent.quote)
- for f in constraint.elements),
- preparer.format_table(table),
+ for f in constraint._elements.values()),
+ preparer.format_table(remote_table),
', '.join(preparer.quote(f.column.name, f.column.quote)
- for f in constraint.elements)
- ))
- if constraint.ondelete is not None:
- self.append(" ON DELETE %s" % constraint.ondelete)
- if constraint.onupdate is not None:
- self.append(" ON UPDATE %s" % constraint.onupdate)
- self.define_constraint_deferrability(constraint)
+ for f in constraint._elements.values())
+ )
+ text += self.define_constraint_cascades(constraint)
+ text += self.define_constraint_deferrability(constraint)
+ return text
def visit_unique_constraint(self, constraint):
- self.append(", \n\t")
+ text = ""
if constraint.name is not None:
- self.append("CONSTRAINT %s " %
- self.preparer.format_constraint(constraint))
- self.append(" UNIQUE (%s)" % (', '.join(self.preparer.quote(c.name, c.quote) for c in constraint)))
- self.define_constraint_deferrability(constraint)
+ text += "CONSTRAINT %s " % self.preparer.format_constraint(constraint)
+ text += " UNIQUE (%s)" % (', '.join(self.preparer.quote(c.name, c.quote) for c in constraint))
+ text += self.define_constraint_deferrability(constraint)
+ return text
+ def define_constraint_cascades(self, constraint):
+ text = ""
+ if constraint.ondelete is not None:
+ text += " ON DELETE %s" % constraint.ondelete
+ if constraint.onupdate is not None:
+ text += " ON UPDATE %s" % constraint.onupdate
+ return text
+
def define_constraint_deferrability(self, constraint):
+ text = ""
if constraint.deferrable is not None:
if constraint.deferrable:
- self.append(" DEFERRABLE")
+ text += " DEFERRABLE"
else:
- self.append(" NOT DEFERRABLE")
+ text += " NOT DEFERRABLE"
if constraint.initially is not None:
- self.append(" INITIALLY %s" % constraint.initially)
+ text += " INITIALLY %s" % constraint.initially
+ return text
+
+
+class GenericTypeCompiler(engine.TypeCompiler):
+ def visit_CHAR(self, type_):
+ return "CHAR" + (type_.length and "(%d)" % type_.length or "")
- def visit_column(self, column):
- pass
+ def visit_NCHAR(self, type_):
+ return "NCHAR" + (type_.length and "(%d)" % type_.length or "")
+
+ def visit_FLOAT(self, type_):
+ return "FLOAT"
- def visit_index(self, index):
- preparer = self.preparer
- self.append("CREATE ")
- if index.unique:
- self.append("UNIQUE ")
- self.append("INDEX %s ON %s (%s)" \
- % (preparer.quote(self._validate_identifier(index.name, True), index.quote),
- preparer.format_table(index.table),
- ', '.join(preparer.quote(c.name, c.quote)
- for c in index.columns)))
- self.execute()
+ def visit_NUMERIC(self, type_):
+ if type_.precision is None:
+ return "NUMERIC"
+ else:
+ return "NUMERIC(%(precision)s, %(scale)s)" % {'precision': type_.precision, 'scale' : type_.scale}
+ def visit_DECIMAL(self, type_):
+ return "DECIMAL"
+
+ def visit_INTEGER(self, type_):
+ return "INTEGER"
-class SchemaDropper(DDLBase):
- def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs):
- super(SchemaDropper, self).__init__(connection, **kwargs)
- self.checkfirst = checkfirst
- self.tables = tables
- self.preparer = dialect.identifier_preparer
- self.dialect = dialect
+ def visit_SMALLINT(self, type_):
+ return "SMALLINT"
- def visit_metadata(self, metadata):
- if self.tables:
- tables = self.tables
- else:
- tables = metadata.tables.values()
- collection = [t for t in reversed(sql_util.sort_tables(tables)) if self._can_drop(t)]
- if self.dialect.supports_alter:
- for alterable in self.find_alterables(collection):
- self.drop_foreignkey(alterable)
- for table in collection:
- self.traverse_single(table)
-
- def _can_drop(self, table):
- self.dialect.validate_identifier(table.name)
- if table.schema:
- self.dialect.validate_identifier(table.schema)
- return not self.checkfirst or self.dialect.has_table(self.connection, table.name, schema=table.schema)
-
- def visit_index(self, index):
- self.append("\nDROP INDEX " + self.preparer.quote(self._validate_identifier(index.name, False), index.quote))
- self.execute()
-
- def drop_foreignkey(self, constraint):
- self.append("ALTER TABLE %s DROP CONSTRAINT %s" % (
- self.preparer.format_table(constraint.table),
- self.preparer.format_constraint(constraint)))
- self.execute()
-
- def visit_table(self, table):
- for listener in table.ddl_listeners['before-drop']:
- listener('before-drop', table, self.connection)
+ def visit_BIGINT(self, type_):
+ return "BIGINT"
- for column in table.columns:
- if column.default is not None:
- self.traverse_single(column.default)
+ def visit_TIMESTAMP(self, type_):
+ return 'TIMESTAMP'
+
+ def visit_DATETIME(self, type_):
+ return "DATETIME"
+
+ def visit_DATE(self, type_):
+ return "DATE"
+
+ def visit_TIME(self, type_):
+ return "TIME"
+
+ def visit_CLOB(self, type_):
+ return "CLOB"
- self.append("\nDROP TABLE " + self.preparer.format_table(table))
- self.execute()
+ def visit_NCLOB(self, type_):
+ return "NCLOB"
- for listener in table.ddl_listeners['after-drop']:
- listener('after-drop', table, self.connection)
+ def visit_VARCHAR(self, type_):
+ return "VARCHAR" + (type_.length and "(%d)" % type_.length or "")
+ def visit_NVARCHAR(self, type_):
+ return "NVARCHAR" + (type_.length and "(%d)" % type_.length or "")
+ def visit_BLOB(self, type_):
+ return "BLOB"
+
+ def visit_BOOLEAN(self, type_):
+ return "BOOLEAN"
+
+ def visit_TEXT(self, type_):
+ return "TEXT"
+
+ def visit_binary(self, type_):
+ return self.visit_BLOB(type_)
+
+ def visit_boolean(self, type_):
+ return self.visit_BOOLEAN(type_)
+
+ def visit_time(self, type_):
+ return self.visit_TIME(type_)
+
+ def visit_datetime(self, type_):
+ return self.visit_DATETIME(type_)
+
+ def visit_date(self, type_):
+ return self.visit_DATE(type_)
+
+ def visit_big_integer(self, type_):
+ return self.visit_BIGINT(type_)
+
+ def visit_small_integer(self, type_):
+ return self.visit_SMALLINT(type_)
+
+ def visit_integer(self, type_):
+ return self.visit_INTEGER(type_)
+
+ def visit_float(self, type_):
+ return self.visit_FLOAT(type_)
+
+ def visit_numeric(self, type_):
+ return self.visit_NUMERIC(type_)
+
+ def visit_string(self, type_):
+ return self.visit_VARCHAR(type_)
+
+ def visit_unicode(self, type_):
+ return self.visit_VARCHAR(type_)
+
+ def visit_text(self, type_):
+ return self.visit_TEXT(type_)
+
+ def visit_unicode_text(self, type_):
+ return self.visit_TEXT(type_)
+
+ def visit_null(self, type_):
+ raise NotImplementedError("Can't generate DDL for the null type")
+
+ def visit_type_decorator(self, type_):
+ return self.process(type_.type_engine(self.dialect))
+
+ def visit_user_defined(self, type_):
+ return type_.get_col_spec()
+
class IdentifierPreparer(object):
"""Handle quoting and case-folding of identifiers based on options."""
@@ -1176,24 +1380,24 @@ class IdentifierPreparer(object):
else:
return (self.format_table(table, use_schema=False), )
+ @util.memoized_property
+ def _r_identifiers(self):
+ initial, final, escaped_final = \
+ [re.escape(s) for s in
+ (self.initial_quote, self.final_quote,
+ self._escape_identifier(self.final_quote))]
+ r = re.compile(
+ r'(?:'
+ r'(?:%(initial)s((?:%(escaped)s|[^%(final)s])+)%(final)s'
+ r'|([^\.]+))(?=\.|$))+' %
+ { 'initial': initial,
+ 'final': final,
+ 'escaped': escaped_final })
+ return r
+
def unformat_identifiers(self, identifiers):
"""Unpack 'schema.table.column'-like strings into components."""
- try:
- r = self._r_identifiers
- except AttributeError:
- initial, final, escaped_final = \
- [re.escape(s) for s in
- (self.initial_quote, self.final_quote,
- self._escape_identifier(self.final_quote))]
- r = re.compile(
- r'(?:'
- r'(?:%(initial)s((?:%(escaped)s|[^%(final)s])+)%(final)s'
- r'|([^\.]+))(?=\.|$))+' %
- { 'initial': initial,
- 'final': final,
- 'escaped': escaped_final })
- self._r_identifiers = r
-
+ r = self._r_identifiers
return [self._unescape_identifier(i)
for i in [a or b for a, b in r.findall(identifiers)]]