diff options
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
| -rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 872 |
1 files changed, 538 insertions, 334 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 6af65ec14..6bfad4a76 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -6,19 +6,23 @@ """Base SQL and DDL compiler implementations. -Provides the :class:`~sqlalchemy.sql.compiler.DefaultCompiler` class, which is -responsible for generating all SQL query strings, as well as -:class:`~sqlalchemy.sql.compiler.SchemaGenerator` and :class:`~sqlalchemy.sql.compiler.SchemaDropper` -which issue CREATE and DROP DDL for tables, sequences, and indexes. - -The elements in this module are used by public-facing constructs like -:class:`~sqlalchemy.sql.expression.ClauseElement` and :class:`~sqlalchemy.engine.Engine`. -While dialect authors will want to be familiar with this module for the purpose of -creating database-specific compilers and schema generators, the module -is otherwise internal to SQLAlchemy. +Classes provided include: + +:class:`~sqlalchemy.sql.compiler.SQLCompiler` - renders SQL +strings + +:class:`~sqlalchemy.sql.compiler.DDLCompiler` - renders DDL +(data definition language) strings + +:class:`~sqlalchemy.sql.compiler.GenericTypeCompiler` - renders +type specification strings. + +To generate user-defined SQL strings, see +:module:`~sqlalchemy.ext.compiler`. + """ -import string, re +import re from sqlalchemy import schema, engine, util, exc from sqlalchemy.sql import operators, functions, util as sql_util, visitors from sqlalchemy.sql import expression as sql @@ -58,40 +62,43 @@ BIND_TEMPLATES = { OPERATORS = { - operators.and_ : 'AND', - operators.or_ : 'OR', - operators.inv : 'NOT', - operators.add : '+', - operators.mul : '*', - operators.sub : '-', - operators.div : '/', - operators.mod : '%', - operators.truediv : '/', - operators.lt : '<', - operators.le : '<=', - operators.ne : '!=', - operators.gt : '>', - operators.ge : '>=', - operators.eq : '=', - operators.distinct_op : 'DISTINCT', - operators.concat_op : '||', - operators.like_op : lambda x, y, escape=None: '%s LIKE %s' % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''), - operators.notlike_op : lambda x, y, escape=None: '%s NOT LIKE %s' % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''), - operators.ilike_op : lambda x, y, escape=None: "lower(%s) LIKE lower(%s)" % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''), - operators.notilike_op : lambda x, y, escape=None: "lower(%s) NOT LIKE lower(%s)" % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''), - operators.between_op : 'BETWEEN', - operators.match_op : 'MATCH', - operators.in_op : 'IN', - operators.notin_op : 'NOT IN', + # binary + operators.and_ : ' AND ', + operators.or_ : ' OR ', + operators.add : ' + ', + operators.mul : ' * ', + operators.sub : ' - ', + # Py2K + operators.div : ' / ', + # end Py2K + operators.mod : ' % ', + operators.truediv : ' / ', + operators.lt : ' < ', + operators.le : ' <= ', + operators.ne : ' != ', + operators.gt : ' > ', + operators.ge : ' >= ', + operators.eq : ' = ', + operators.concat_op : ' || ', + operators.between_op : ' BETWEEN ', + operators.match_op : ' MATCH ', + operators.in_op : ' IN ', + operators.notin_op : ' NOT IN ', operators.comma_op : ', ', - operators.desc_op : 'DESC', - operators.asc_op : 'ASC', - operators.from_ : 'FROM', - operators.as_ : 'AS', - operators.exists : 'EXISTS', - operators.is_ : 'IS', - operators.isnot : 'IS NOT', - operators.collate : 'COLLATE', + operators.from_ : ' FROM ', + operators.as_ : ' AS ', + operators.is_ : ' IS ', + operators.isnot : ' IS NOT ', + operators.collate : ' COLLATE ', + + # unary + operators.exists : 'EXISTS ', + operators.distinct_op : 'DISTINCT ', + operators.inv : 'NOT ', + + # modifiers + operators.desc_op : ' DESC', + operators.asc_op : ' ASC', } FUNCTIONS = { @@ -140,7 +147,7 @@ class _CompileLabel(visitors.Visitable): def quote(self): return self.element.quote -class DefaultCompiler(engine.Compiled): +class SQLCompiler(engine.Compiled): """Default implementation of Compiled. Compiles ClauseElements into SQL strings. Uses a similar visit @@ -148,14 +155,14 @@ class DefaultCompiler(engine.Compiled): """ - operators = OPERATORS - functions = FUNCTIONS extract_map = EXTRACT_MAP - # if we are insert/update/delete. - # set to true when we visit an INSERT, UPDATE or DELETE + # class-level defaults which can be set at the instance + # level to define if this Compiled instance represents + # INSERT/UPDATE/DELETE isdelete = isinsert = isupdate = False - + returning = None + def __init__(self, dialect, statement, column_keys=None, inline=False, **kwargs): """Construct a new ``DefaultCompiler`` object. @@ -170,7 +177,9 @@ class DefaultCompiler(engine.Compiled): statement. """ - engine.Compiled.__init__(self, dialect, statement, column_keys, **kwargs) + engine.Compiled.__init__(self, dialect, statement, **kwargs) + + self.column_keys = column_keys # compile INSERT/UPDATE defaults/sequences inlined (no pre-execute) self.inline = inline or getattr(statement, 'inline', False) @@ -210,12 +219,6 @@ class DefaultCompiler(engine.Compiled): # or dialect.max_identifier_length self.truncated_names = {} - def compile(self): - self.string = self.process(self.statement) - - def process(self, obj, **kwargs): - return obj._compiler_dispatch(self, **kwargs) - def is_subquery(self): return len(self.stack) > 1 @@ -223,7 +226,6 @@ class DefaultCompiler(engine.Compiled): """return a dictionary of bind parameter keys and values""" if params: - params = util.column_dict(params) pd = {} for bindparam, name in self.bind_names.iteritems(): for paramname in (bindparam.key, bindparam.shortname, name): @@ -245,7 +247,10 @@ class DefaultCompiler(engine.Compiled): pd[self.bind_names[bindparam]] = bindparam.value return pd - params = property(construct_params) + params = property(construct_params, doc=""" + Return the bind params for this compiled object. + + """) def default_from(self): """Called when a SELECT statement has no froms, and no FROM clause is to be appended. @@ -267,10 +272,11 @@ class DefaultCompiler(engine.Compiled): self._truncated_identifier("colident", label.name) or label.name if result_map is not None: - result_map[labelname.lower()] = (label.name, (label, label.element, labelname), label.element.type) + result_map[labelname.lower()] = \ + (label.name, (label, label.element, labelname), label.element.type) - return self.process(label.element) + " " + \ - self.operator_string(operators.as_) + " " + \ + return self.process(label.element) + \ + OPERATORS[operators.as_] + \ self.preparer.format_label(label, labelname) else: return self.process(label.element) @@ -292,14 +298,17 @@ class DefaultCompiler(engine.Compiled): return name else: if column.table.schema: - schema_prefix = self.preparer.quote_schema(column.table.schema, column.table.quote_schema) + '.' + schema_prefix = self.preparer.quote_schema( + column.table.schema, + column.table.quote_schema) + '.' else: schema_prefix = '' tablename = column.table.name tablename = isinstance(tablename, sql._generated_label) and \ self._truncated_identifier("alias", tablename) or tablename - return schema_prefix + self.preparer.quote(tablename, column.table.quote) + "." + name + return schema_prefix + \ + self.preparer.quote(tablename, column.table.quote) + "." + name def escape_literal_column(self, text): """provide escaping for the literal_column() construct.""" @@ -314,7 +323,7 @@ class DefaultCompiler(engine.Compiled): return index.name def visit_typeclause(self, typeclause, **kwargs): - return typeclause.type.dialect_impl(self.dialect).get_col_spec() + return self.dialect.type_compiler.process(typeclause.type) def post_process_text(self, text): return text @@ -343,10 +352,8 @@ class DefaultCompiler(engine.Compiled): sep = clauselist.operator if sep is None: sep = " " - elif sep is operators.comma_op: - sep = ', ' else: - sep = " " + self.operator_string(clauselist.operator) + " " + sep = OPERATORS[clauselist.operator] return sep.join(s for s in (self.process(c) for c in clauselist.clauses) if s is not None) @@ -362,7 +369,8 @@ class DefaultCompiler(engine.Compiled): return x def visit_cast(self, cast, **kwargs): - return "CAST(%s AS %s)" % (self.process(cast.clause), self.process(cast.typeclause)) + return "CAST(%s AS %s)" % \ + (self.process(cast.clause), self.process(cast.typeclause)) def visit_extract(self, extract, **kwargs): field = self.extract_map.get(extract.field, extract.field) @@ -372,26 +380,26 @@ class DefaultCompiler(engine.Compiled): if result_map is not None: result_map[func.name.lower()] = (func.name, None, func.type) - name = self.function_string(func) - - if util.callable(name): - return name(*[self.process(x) for x in func.clauses]) + disp = getattr(self, "visit_%s_func" % func.name.lower(), None) + if disp: + return disp(func, **kwargs) else: - return ".".join(func.packagenames + [name]) % {'expr':self.function_argspec(func)} + name = FUNCTIONS.get(func.__class__, func.name + "%(expr)s") + return ".".join(func.packagenames + [name]) % \ + {'expr':self.function_argspec(func, **kwargs)} def function_argspec(self, func, **kwargs): return self.process(func.clause_expr, **kwargs) - def function_string(self, func): - return self.functions.get(func.__class__, self.functions.get(func.name, func.name + "%(expr)s")) - def visit_compound_select(self, cs, asfrom=False, parens=True, **kwargs): entry = self.stack and self.stack[-1] or {} self.stack.append({'from':entry.get('from', None), 'iswrapper':True}) - text = string.join((self.process(c, asfrom=asfrom, parens=False, compound_index=i) - for i, c in enumerate(cs.selects)), - " " + cs.keyword + " ") + text = (" " + cs.keyword + " ").join( + (self.process(c, asfrom=asfrom, parens=False, compound_index=i) + for i, c in enumerate(cs.selects)) + ) + group_by = self.process(cs._group_by_clause, asfrom=asfrom) if group_by: text += " GROUP BY " + group_by @@ -408,27 +416,57 @@ class DefaultCompiler(engine.Compiled): def visit_unary(self, unary, **kw): s = self.process(unary.element, **kw) if unary.operator: - s = self.operator_string(unary.operator) + " " + s + s = OPERATORS[unary.operator] + s if unary.modifier: - s = s + " " + self.operator_string(unary.modifier) + s = s + OPERATORS[unary.modifier] return s def visit_binary(self, binary, **kwargs): - op = self.operator_string(binary.operator) - if util.callable(op): - return op(self.process(binary.left), self.process(binary.right), **binary.modifiers) - else: - return self.process(binary.left) + " " + op + " " + self.process(binary.right) + + return self._operator_dispatch(binary.operator, + binary, + lambda opstr: self.process(binary.left) + opstr + self.process(binary.right), + **kwargs + ) - def operator_string(self, operator): - return self.operators.get(operator, str(operator)) + def visit_like_op(self, binary, **kw): + escape = binary.modifiers.get("escape", None) + return '%s LIKE %s' % (self.process(binary.left), self.process(binary.right)) \ + + (escape and ' ESCAPE \'%s\'' % escape or '') + def visit_notlike_op(self, binary, **kw): + escape = binary.modifiers.get("escape", None) + return '%s NOT LIKE %s' % (self.process(binary.left), self.process(binary.right)) \ + + (escape and ' ESCAPE \'%s\'' % escape or '') + + def visit_ilike_op(self, binary, **kw): + escape = binary.modifiers.get("escape", None) + return 'lower(%s) LIKE lower(%s)' % (self.process(binary.left), self.process(binary.right)) \ + + (escape and ' ESCAPE \'%s\'' % escape or '') + + def visit_notilike_op(self, binary, **kw): + escape = binary.modifiers.get("escape", None) + return 'lower(%s) NOT LIKE lower(%s)' % (self.process(binary.left), self.process(binary.right)) \ + + (escape and ' ESCAPE \'%s\'' % escape or '') + + def _operator_dispatch(self, operator, element, fn, **kw): + if util.callable(operator): + disp = getattr(self, "visit_%s" % operator.__name__, None) + if disp: + return disp(element, **kw) + else: + return fn(OPERATORS[operator]) + else: + return fn(" " + operator + " ") + def visit_bindparam(self, bindparam, **kwargs): name = self._truncate_bindparam(bindparam) if name in self.binds: existing = self.binds[name] if existing is not bindparam and (existing.unique or bindparam.unique): - raise exc.CompileError("Bind parameter '%s' conflicts with unique bind parameter of the same name" % bindparam.key) + raise exc.CompileError( + "Bind parameter '%s' conflicts with unique bind parameter of the same name" % bindparam.key + ) self.binds[bindparam.key] = self.binds[name] = bindparam return self.bindparam_string(name) @@ -491,7 +529,7 @@ class DefaultCompiler(engine.Compiled): if isinstance(column, sql._Label): return column - if select.use_labels and column._label: + if select and select.use_labels and column._label: return _CompileLabel(column, column._label) if \ @@ -501,13 +539,15 @@ class DefaultCompiler(engine.Compiled): column.table is not None and \ not isinstance(column.table, sql.Select): return _CompileLabel(column, sql._generated_label(column.name)) - elif not isinstance(column, (sql._UnaryExpression, sql._TextClause, sql._BindParamClause)) \ + elif not isinstance(column, + (sql._UnaryExpression, sql._TextClause, sql._BindParamClause)) \ and (not hasattr(column, 'name') or isinstance(column, sql.Function)): return _CompileLabel(column, column.anon_label) else: return column - def visit_select(self, select, asfrom=False, parens=True, iswrapper=False, compound_index=1, **kwargs): + def visit_select(self, select, asfrom=False, parens=True, + iswrapper=False, compound_index=1, **kwargs): entry = self.stack and self.stack[-1] or {} @@ -583,8 +623,10 @@ class DefaultCompiler(engine.Compiled): return text def get_select_precolumns(self, select): - """Called when building a ``SELECT`` statement, position is just before column list.""" - + """Called when building a ``SELECT`` statement, position is just before + column list. + + """ return select._distinct and "DISTINCT " or "" def order_by_clause(self, select): @@ -613,14 +655,16 @@ class DefaultCompiler(engine.Compiled): def visit_table(self, table, asfrom=False, **kwargs): if asfrom: if getattr(table, "schema", None): - return self.preparer.quote_schema(table.schema, table.quote_schema) + "." + self.preparer.quote(table.name, table.quote) + return self.preparer.quote_schema(table.schema, table.quote_schema) + \ + "." + self.preparer.quote(table.name, table.quote) else: return self.preparer.quote(table.name, table.quote) else: return "" def visit_join(self, join, asfrom=False, **kwargs): - return (self.process(join.left, asfrom=True) + (join.isouter and " LEFT OUTER JOIN " or " JOIN ") + \ + return (self.process(join.left, asfrom=True) + \ + (join.isouter and " LEFT OUTER JOIN " or " JOIN ") + \ self.process(join.right, asfrom=True) + " ON " + self.process(join.onclause)) def visit_sequence(self, seq): @@ -629,41 +673,75 @@ class DefaultCompiler(engine.Compiled): def visit_insert(self, insert_stmt): self.isinsert = True colparams = self._get_colparams(insert_stmt) - preparer = self.preparer - - insert = ' '.join(["INSERT"] + - [self.process(x) for x in insert_stmt._prefixes]) - if not colparams and not self.dialect.supports_default_values and not self.dialect.supports_empty_insert: + if not colparams and \ + not self.dialect.supports_default_values and \ + not self.dialect.supports_empty_insert: raise exc.CompileError( "The version of %s you are using does not support empty inserts." % self.dialect.name) - elif not colparams and self.dialect.supports_default_values: - return (insert + " INTO %s DEFAULT VALUES" % ( - (preparer.format_table(insert_stmt.table),))) - else: - return (insert + " INTO %s (%s) VALUES (%s)" % - (preparer.format_table(insert_stmt.table), - ', '.join([preparer.format_column(c[0]) - for c in colparams]), - ', '.join([c[1] for c in colparams]))) + preparer = self.preparer + supports_default_values = self.dialect.supports_default_values + + text = "INSERT" + + prefixes = [self.process(x) for x in insert_stmt._prefixes] + if prefixes: + text += " " + " ".join(prefixes) + + text += " INTO " + preparer.format_table(insert_stmt.table) + + if colparams or not supports_default_values: + text += " (%s)" % ', '.join([preparer.format_column(c[0]) + for c in colparams]) + + if self.returning or insert_stmt._returning: + self.returning = self.returning or insert_stmt._returning + returning_clause = self.returning_clause(insert_stmt, self.returning) + + # cheating + if returning_clause.startswith("OUTPUT"): + text += " " + returning_clause + returning_clause = None + + if not colparams and supports_default_values: + text += " DEFAULT VALUES" + else: + text += " VALUES (%s)" % \ + ', '.join([c[1] for c in colparams]) + + if self.returning and returning_clause: + text += " " + returning_clause + + return text + def visit_update(self, update_stmt): self.stack.append({'from': set([update_stmt.table])}) self.isupdate = True colparams = self._get_colparams(update_stmt) - text = ' '.join(( - "UPDATE", - self.preparer.format_table(update_stmt.table), - 'SET', - ', '.join(self.preparer.quote(c[0].name, c[0].quote) + '=' + c[1] - for c in colparams) - )) - + text = "UPDATE " + self.preparer.format_table(update_stmt.table) + + text += ' SET ' + \ + ', '.join( + self.preparer.quote(c[0].name, c[0].quote) + '=' + c[1] + for c in colparams + ) + + if update_stmt._returning: + self.returning = update_stmt._returning + returning_clause = self.returning_clause(update_stmt, update_stmt._returning) + if returning_clause.startswith("OUTPUT"): + text += " " + returning_clause + returning_clause = None + if update_stmt._whereclause: text += " WHERE " + self.process(update_stmt._whereclause) + if self.returning and returning_clause: + text += " " + returning_clause + self.stack.pop(-1) return text @@ -681,7 +759,8 @@ class DefaultCompiler(engine.Compiled): self.postfetch = [] self.prefetch = [] - + self.returning = [] + # no parameters in the statement, no parameters in the # compiled params - return binds for all columns if self.column_keys is None and stmt.parameters is None: @@ -701,6 +780,15 @@ class DefaultCompiler(engine.Compiled): # create a list of column assignment clauses as tuples values = [] + + need_pks = self.isinsert and \ + not self.inline and \ + not self.statement._returning + + implicit_returning = need_pks and \ + self.dialect.implicit_returning and \ + stmt.table.implicit_returning + for c in stmt.table.columns: if c.key in parameters: value = parameters[c.key] @@ -710,19 +798,48 @@ class DefaultCompiler(engine.Compiled): self.postfetch.append(c) value = self.process(value.self_group()) values.append((c, value)) + elif isinstance(c, schema.Column): if self.isinsert: - if (c.primary_key and self.dialect.preexecute_pk_sequences and not self.inline): - if (((isinstance(c.default, schema.Sequence) and - not c.default.optional) or - not self.dialect.supports_pk_autoincrement) or - (c.default is not None and - not isinstance(c.default, schema.Sequence))): - values.append((c, create_bind_param(c, None))) - self.prefetch.append(c) + if c.primary_key and \ + need_pks and \ + ( + c is not stmt.table._autoincrement_column or + not self.dialect.postfetch_lastrowid + ): + + if implicit_returning: + if isinstance(c.default, schema.Sequence): + proc = self.process(c.default) + if proc is not None: + values.append((c, proc)) + self.returning.append(c) + elif isinstance(c.default, schema.ColumnDefault) and \ + isinstance(c.default.arg, sql.ClauseElement): + values.append((c, self.process(c.default.arg.self_group()))) + self.returning.append(c) + elif c.default is not None: + values.append((c, create_bind_param(c, None))) + self.prefetch.append(c) + else: + self.returning.append(c) + else: + if ( + c.default is not None and \ + ( + self.dialect.supports_sequences or + not isinstance(c.default, schema.Sequence) + ) + ) or \ + self.dialect.preexecute_autoincrement_sequences: + + values.append((c, create_bind_param(c, None))) + self.prefetch.append(c) + elif isinstance(c.default, schema.ColumnDefault): if isinstance(c.default.arg, sql.ClauseElement): values.append((c, self.process(c.default.arg.self_group()))) + if not c.primary_key: # dont add primary key column to postfetch self.postfetch.append(c) @@ -759,9 +876,19 @@ class DefaultCompiler(engine.Compiled): text = "DELETE FROM " + self.preparer.format_table(delete_stmt.table) + if delete_stmt._returning: + self.returning = delete_stmt._returning + returning_clause = self.returning_clause(delete_stmt, delete_stmt._returning) + if returning_clause.startswith("OUTPUT"): + text += " " + returning_clause + returning_clause = None + if delete_stmt._whereclause: text += " WHERE " + self.process(delete_stmt._whereclause) + if self.returning and returning_clause: + text += " " + returning_clause + self.stack.pop(-1) return text @@ -775,110 +902,146 @@ class DefaultCompiler(engine.Compiled): def visit_release_savepoint(self, savepoint_stmt): return "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt) - def __str__(self): - return self.string or '' - -class DDLBase(engine.SchemaIterator): - def find_alterables(self, tables): - alterables = [] - class FindAlterables(schema.SchemaVisitor): - def visit_foreign_key_constraint(self, constraint): - if constraint.use_alter and constraint.table in tables: - alterables.append(constraint) - findalterables = FindAlterables() - for table in tables: - for c in table.constraints: - findalterables.traverse(c) - return alterables - - def _validate_identifier(self, ident, truncate): - if truncate: - if len(ident) > self.dialect.max_identifier_length: - counter = getattr(self, 'counter', 0) - self.counter = counter + 1 - return ident[0:self.dialect.max_identifier_length - 6] + "_" + hex(self.counter)[2:] - else: - return ident - else: - self.dialect.validate_identifier(ident) - return ident - - -class SchemaGenerator(DDLBase): - def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs): - super(SchemaGenerator, self).__init__(connection, **kwargs) - self.checkfirst = checkfirst - self.tables = tables and set(tables) or None - self.preparer = dialect.identifier_preparer - self.dialect = dialect - def get_column_specification(self, column, first_pk=False): - raise NotImplementedError() +class DDLCompiler(engine.Compiled): + @property + def preparer(self): + return self.dialect.identifier_preparer - def _can_create(self, table): - self.dialect.validate_identifier(table.name) - if table.schema: - self.dialect.validate_identifier(table.schema) - return not self.checkfirst or not self.dialect.has_table(self.connection, table.name, schema=table.schema) + def construct_params(self, params=None): + return None + + def visit_ddl(self, ddl, **kwargs): + # table events can substitute table and schema name + context = ddl.context + if isinstance(ddl.schema_item, schema.Table): + context = context.copy() + + preparer = self.dialect.identifier_preparer + path = preparer.format_table_seq(ddl.schema_item) + if len(path) == 1: + table, sch = path[0], '' + else: + table, sch = path[-1], path[0] - def visit_metadata(self, metadata): - if self.tables: - tables = self.tables - else: - tables = metadata.tables.values() - collection = [t for t in sql_util.sort_tables(tables) if self._can_create(t)] - for table in collection: - self.traverse_single(table) - if self.dialect.supports_alter: - for alterable in self.find_alterables(collection): - self.add_foreignkey(alterable) - - def visit_table(self, table): - for listener in table.ddl_listeners['before-create']: - listener('before-create', table, self.connection) + context.setdefault('table', table) + context.setdefault('schema', sch) + context.setdefault('fullname', preparer.format_table(ddl.schema_item)) + + return ddl.statement % context - for column in table.columns: - if column.default is not None: - self.traverse_single(column.default) + def visit_create_table(self, create): + table = create.element + preparer = self.dialect.identifier_preparer - self.append("\n" + " ".join(['CREATE'] + - table._prefixes + + text = "\n" + " ".join(['CREATE'] + \ + table._prefixes + \ ['TABLE', - self.preparer.format_table(table), - "("])) + preparer.format_table(table), + "("]) separator = "\n" # if only one primary key, specify it along with the column first_pk = False for column in table.columns: - self.append(separator) + text += separator separator = ", \n" - self.append("\t" + self.get_column_specification(column, first_pk=column.primary_key and not first_pk)) + text += "\t" + self.get_column_specification(column, first_pk=column.primary_key and not first_pk) if column.primary_key: first_pk = True - for constraint in column.constraints: - self.traverse_single(constraint) + const = " ".join(self.process(constraint) for constraint in column.constraints) + if const: + text += " " + const # On some DB order is significant: visit PK first, then the # other constraints (engine.ReflectionTest.testbasic failed on FB2) if table.primary_key: - self.traverse_single(table.primary_key) - for constraint in [c for c in table.constraints if c is not table.primary_key]: - self.traverse_single(constraint) + text += ", \n\t" + self.process(table.primary_key) + + const = ", \n\t".join( + self.process(constraint) for constraint in table.constraints + if constraint is not table.primary_key + and constraint.inline_ddl + and (not self.dialect.supports_alter or not getattr(constraint, 'use_alter', False)) + ) + if const: + text += ", \n\t" + const + + text += "\n)%s\n\n" % self.post_create_table(table) + return text + + def visit_drop_table(self, drop): + ret = "\nDROP TABLE " + self.preparer.format_table(drop.element) + if drop.cascade: + ret += " CASCADE CONSTRAINTS" + return ret + + def visit_create_index(self, create): + index = create.element + preparer = self.preparer + text = "CREATE " + if index.unique: + text += "UNIQUE " + text += "INDEX %s ON %s (%s)" \ + % (preparer.quote(self._validate_identifier(index.name, True), index.quote), + preparer.format_table(index.table), + ', '.join(preparer.quote(c.name, c.quote) + for c in index.columns)) + return text - self.append("\n)%s\n\n" % self.post_create_table(table)) - self.execute() + def visit_drop_index(self, drop): + index = drop.element + return "\nDROP INDEX " + \ + self.preparer.quote(self._validate_identifier(index.name, False), index.quote) - if hasattr(table, 'indexes'): - for index in table.indexes: - self.traverse_single(index) + def visit_add_constraint(self, create): + preparer = self.preparer + return "ALTER TABLE %s ADD %s" % ( + self.preparer.format_table(create.element.table), + self.process(create.element) + ) + + def visit_drop_constraint(self, drop): + preparer = self.preparer + return "ALTER TABLE %s DROP CONSTRAINT %s%s" % ( + self.preparer.format_table(drop.element.table), + self.preparer.format_constraint(drop.element), + " CASCADE" if drop.cascade else "" + ) + + def get_column_specification(self, column, **kwargs): + colspec = self.preparer.format_column(column) + " " + \ + self.dialect.type_compiler.process(column.type) + default = self.get_column_default_string(column) + if default is not None: + colspec += " DEFAULT " + default - for listener in table.ddl_listeners['after-create']: - listener('after-create', table, self.connection) + if not column.nullable: + colspec += " NOT NULL" + return colspec def post_create_table(self, table): return '' + def _compile(self, tocompile, parameters): + """compile the given string/parameters using this SchemaGenerator's dialect.""" + + compiler = self.dialect.statement_compiler(self.dialect, tocompile, parameters) + compiler.compile() + return compiler + + def _validate_identifier(self, ident, truncate): + if truncate: + if len(ident) > self.dialect.max_identifier_length: + counter = getattr(self, 'counter', 0) + self.counter = counter + 1 + return ident[0:self.dialect.max_identifier_length - 6] + "_" + hex(self.counter)[2:] + else: + return ident + else: + self.dialect.validate_identifier(ident) + return ident + def get_column_default_string(self, column): if isinstance(column.server_default, schema.DefaultClause): if isinstance(column.server_default.arg, basestring): @@ -888,149 +1051,190 @@ class SchemaGenerator(DDLBase): else: return None - def _compile(self, tocompile, parameters): - """compile the given string/parameters using this SchemaGenerator's dialect.""" - compiler = self.dialect.statement_compiler(self.dialect, tocompile, parameters) - compiler.compile() - return compiler - def visit_check_constraint(self, constraint): - self.append(", \n\t") + text = "" if constraint.name is not None: - self.append("CONSTRAINT %s " % - self.preparer.format_constraint(constraint)) - self.append(" CHECK (%s)" % constraint.sqltext) - self.define_constraint_deferrability(constraint) + text += "CONSTRAINT %s " % \ + self.preparer.format_constraint(constraint) + text += " CHECK (%s)" % constraint.sqltext + text += self.define_constraint_deferrability(constraint) + return text def visit_column_check_constraint(self, constraint): - self.append(" CHECK (%s)" % constraint.sqltext) - self.define_constraint_deferrability(constraint) + text = " CHECK (%s)" % constraint.sqltext + text += self.define_constraint_deferrability(constraint) + return text def visit_primary_key_constraint(self, constraint): if len(constraint) == 0: - return - self.append(", \n\t") + return '' + text = "" if constraint.name is not None: - self.append("CONSTRAINT %s " % self.preparer.format_constraint(constraint)) - self.append("PRIMARY KEY ") - self.append("(%s)" % ', '.join(self.preparer.quote(c.name, c.quote) - for c in constraint)) - self.define_constraint_deferrability(constraint) + text += "CONSTRAINT %s " % self.preparer.format_constraint(constraint) + text += "PRIMARY KEY " + text += "(%s)" % ', '.join(self.preparer.quote(c.name, c.quote) + for c in constraint) + text += self.define_constraint_deferrability(constraint) + return text def visit_foreign_key_constraint(self, constraint): - if constraint.use_alter and self.dialect.supports_alter: - return - self.append(", \n\t ") - self.define_foreign_key(constraint) - - def add_foreignkey(self, constraint): - self.append("ALTER TABLE %s ADD " % self.preparer.format_table(constraint.table)) - self.define_foreign_key(constraint) - self.execute() - - def define_foreign_key(self, constraint): - preparer = self.preparer + preparer = self.dialect.identifier_preparer + text = "" if constraint.name is not None: - self.append("CONSTRAINT %s " % - preparer.format_constraint(constraint)) - table = list(constraint.elements)[0].column.table - self.append("FOREIGN KEY(%s) REFERENCES %s (%s)" % ( + text += "CONSTRAINT %s " % \ + preparer.format_constraint(constraint) + remote_table = list(constraint._elements.values())[0].column.table + text += "FOREIGN KEY(%s) REFERENCES %s (%s)" % ( ', '.join(preparer.quote(f.parent.name, f.parent.quote) - for f in constraint.elements), - preparer.format_table(table), + for f in constraint._elements.values()), + preparer.format_table(remote_table), ', '.join(preparer.quote(f.column.name, f.column.quote) - for f in constraint.elements) - )) - if constraint.ondelete is not None: - self.append(" ON DELETE %s" % constraint.ondelete) - if constraint.onupdate is not None: - self.append(" ON UPDATE %s" % constraint.onupdate) - self.define_constraint_deferrability(constraint) + for f in constraint._elements.values()) + ) + text += self.define_constraint_cascades(constraint) + text += self.define_constraint_deferrability(constraint) + return text def visit_unique_constraint(self, constraint): - self.append(", \n\t") + text = "" if constraint.name is not None: - self.append("CONSTRAINT %s " % - self.preparer.format_constraint(constraint)) - self.append(" UNIQUE (%s)" % (', '.join(self.preparer.quote(c.name, c.quote) for c in constraint))) - self.define_constraint_deferrability(constraint) + text += "CONSTRAINT %s " % self.preparer.format_constraint(constraint) + text += " UNIQUE (%s)" % (', '.join(self.preparer.quote(c.name, c.quote) for c in constraint)) + text += self.define_constraint_deferrability(constraint) + return text + def define_constraint_cascades(self, constraint): + text = "" + if constraint.ondelete is not None: + text += " ON DELETE %s" % constraint.ondelete + if constraint.onupdate is not None: + text += " ON UPDATE %s" % constraint.onupdate + return text + def define_constraint_deferrability(self, constraint): + text = "" if constraint.deferrable is not None: if constraint.deferrable: - self.append(" DEFERRABLE") + text += " DEFERRABLE" else: - self.append(" NOT DEFERRABLE") + text += " NOT DEFERRABLE" if constraint.initially is not None: - self.append(" INITIALLY %s" % constraint.initially) + text += " INITIALLY %s" % constraint.initially + return text + + +class GenericTypeCompiler(engine.TypeCompiler): + def visit_CHAR(self, type_): + return "CHAR" + (type_.length and "(%d)" % type_.length or "") - def visit_column(self, column): - pass + def visit_NCHAR(self, type_): + return "NCHAR" + (type_.length and "(%d)" % type_.length or "") + + def visit_FLOAT(self, type_): + return "FLOAT" - def visit_index(self, index): - preparer = self.preparer - self.append("CREATE ") - if index.unique: - self.append("UNIQUE ") - self.append("INDEX %s ON %s (%s)" \ - % (preparer.quote(self._validate_identifier(index.name, True), index.quote), - preparer.format_table(index.table), - ', '.join(preparer.quote(c.name, c.quote) - for c in index.columns))) - self.execute() + def visit_NUMERIC(self, type_): + if type_.precision is None: + return "NUMERIC" + else: + return "NUMERIC(%(precision)s, %(scale)s)" % {'precision': type_.precision, 'scale' : type_.scale} + def visit_DECIMAL(self, type_): + return "DECIMAL" + + def visit_INTEGER(self, type_): + return "INTEGER" -class SchemaDropper(DDLBase): - def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs): - super(SchemaDropper, self).__init__(connection, **kwargs) - self.checkfirst = checkfirst - self.tables = tables - self.preparer = dialect.identifier_preparer - self.dialect = dialect + def visit_SMALLINT(self, type_): + return "SMALLINT" - def visit_metadata(self, metadata): - if self.tables: - tables = self.tables - else: - tables = metadata.tables.values() - collection = [t for t in reversed(sql_util.sort_tables(tables)) if self._can_drop(t)] - if self.dialect.supports_alter: - for alterable in self.find_alterables(collection): - self.drop_foreignkey(alterable) - for table in collection: - self.traverse_single(table) - - def _can_drop(self, table): - self.dialect.validate_identifier(table.name) - if table.schema: - self.dialect.validate_identifier(table.schema) - return not self.checkfirst or self.dialect.has_table(self.connection, table.name, schema=table.schema) - - def visit_index(self, index): - self.append("\nDROP INDEX " + self.preparer.quote(self._validate_identifier(index.name, False), index.quote)) - self.execute() - - def drop_foreignkey(self, constraint): - self.append("ALTER TABLE %s DROP CONSTRAINT %s" % ( - self.preparer.format_table(constraint.table), - self.preparer.format_constraint(constraint))) - self.execute() - - def visit_table(self, table): - for listener in table.ddl_listeners['before-drop']: - listener('before-drop', table, self.connection) + def visit_BIGINT(self, type_): + return "BIGINT" - for column in table.columns: - if column.default is not None: - self.traverse_single(column.default) + def visit_TIMESTAMP(self, type_): + return 'TIMESTAMP' + + def visit_DATETIME(self, type_): + return "DATETIME" + + def visit_DATE(self, type_): + return "DATE" + + def visit_TIME(self, type_): + return "TIME" + + def visit_CLOB(self, type_): + return "CLOB" - self.append("\nDROP TABLE " + self.preparer.format_table(table)) - self.execute() + def visit_NCLOB(self, type_): + return "NCLOB" - for listener in table.ddl_listeners['after-drop']: - listener('after-drop', table, self.connection) + def visit_VARCHAR(self, type_): + return "VARCHAR" + (type_.length and "(%d)" % type_.length or "") + def visit_NVARCHAR(self, type_): + return "NVARCHAR" + (type_.length and "(%d)" % type_.length or "") + def visit_BLOB(self, type_): + return "BLOB" + + def visit_BOOLEAN(self, type_): + return "BOOLEAN" + + def visit_TEXT(self, type_): + return "TEXT" + + def visit_binary(self, type_): + return self.visit_BLOB(type_) + + def visit_boolean(self, type_): + return self.visit_BOOLEAN(type_) + + def visit_time(self, type_): + return self.visit_TIME(type_) + + def visit_datetime(self, type_): + return self.visit_DATETIME(type_) + + def visit_date(self, type_): + return self.visit_DATE(type_) + + def visit_big_integer(self, type_): + return self.visit_BIGINT(type_) + + def visit_small_integer(self, type_): + return self.visit_SMALLINT(type_) + + def visit_integer(self, type_): + return self.visit_INTEGER(type_) + + def visit_float(self, type_): + return self.visit_FLOAT(type_) + + def visit_numeric(self, type_): + return self.visit_NUMERIC(type_) + + def visit_string(self, type_): + return self.visit_VARCHAR(type_) + + def visit_unicode(self, type_): + return self.visit_VARCHAR(type_) + + def visit_text(self, type_): + return self.visit_TEXT(type_) + + def visit_unicode_text(self, type_): + return self.visit_TEXT(type_) + + def visit_null(self, type_): + raise NotImplementedError("Can't generate DDL for the null type") + + def visit_type_decorator(self, type_): + return self.process(type_.type_engine(self.dialect)) + + def visit_user_defined(self, type_): + return type_.get_col_spec() + class IdentifierPreparer(object): """Handle quoting and case-folding of identifiers based on options.""" @@ -1176,24 +1380,24 @@ class IdentifierPreparer(object): else: return (self.format_table(table, use_schema=False), ) + @util.memoized_property + def _r_identifiers(self): + initial, final, escaped_final = \ + [re.escape(s) for s in + (self.initial_quote, self.final_quote, + self._escape_identifier(self.final_quote))] + r = re.compile( + r'(?:' + r'(?:%(initial)s((?:%(escaped)s|[^%(final)s])+)%(final)s' + r'|([^\.]+))(?=\.|$))+' % + { 'initial': initial, + 'final': final, + 'escaped': escaped_final }) + return r + def unformat_identifiers(self, identifiers): """Unpack 'schema.table.column'-like strings into components.""" - try: - r = self._r_identifiers - except AttributeError: - initial, final, escaped_final = \ - [re.escape(s) for s in - (self.initial_quote, self.final_quote, - self._escape_identifier(self.final_quote))] - r = re.compile( - r'(?:' - r'(?:%(initial)s((?:%(escaped)s|[^%(final)s])+)%(final)s' - r'|([^\.]+))(?=\.|$))+' % - { 'initial': initial, - 'final': final, - 'escaped': escaped_final }) - self._r_identifiers = r - + r = self._r_identifiers return [self._unescape_identifier(i) for i in [a or b for a, b in r.findall(identifiers)]] |
