diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2009-08-06 21:11:27 +0000 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2009-08-06 21:11:27 +0000 |
| commit | 8fc5005dfe3eb66a46470ad8a8c7b95fc4d6bdca (patch) | |
| tree | ae9e27d12c9fbf8297bb90469509e1cb6a206242 /lib/sqlalchemy/sql | |
| parent | 7638aa7f242c6ea3d743aa9100e32be2052546a6 (diff) | |
| download | sqlalchemy-8fc5005dfe3eb66a46470ad8a8c7b95fc4d6bdca.tar.gz | |
merge 0.6 series to trunk.
Diffstat (limited to 'lib/sqlalchemy/sql')
| -rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 872 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/expression.py | 675 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/operators.py | 10 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/visitors.py | 5 |
4 files changed, 1016 insertions, 546 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 6af65ec14..6bfad4a76 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -6,19 +6,23 @@ """Base SQL and DDL compiler implementations. -Provides the :class:`~sqlalchemy.sql.compiler.DefaultCompiler` class, which is -responsible for generating all SQL query strings, as well as -:class:`~sqlalchemy.sql.compiler.SchemaGenerator` and :class:`~sqlalchemy.sql.compiler.SchemaDropper` -which issue CREATE and DROP DDL for tables, sequences, and indexes. - -The elements in this module are used by public-facing constructs like -:class:`~sqlalchemy.sql.expression.ClauseElement` and :class:`~sqlalchemy.engine.Engine`. -While dialect authors will want to be familiar with this module for the purpose of -creating database-specific compilers and schema generators, the module -is otherwise internal to SQLAlchemy. +Classes provided include: + +:class:`~sqlalchemy.sql.compiler.SQLCompiler` - renders SQL +strings + +:class:`~sqlalchemy.sql.compiler.DDLCompiler` - renders DDL +(data definition language) strings + +:class:`~sqlalchemy.sql.compiler.GenericTypeCompiler` - renders +type specification strings. + +To generate user-defined SQL strings, see +:module:`~sqlalchemy.ext.compiler`. + """ -import string, re +import re from sqlalchemy import schema, engine, util, exc from sqlalchemy.sql import operators, functions, util as sql_util, visitors from sqlalchemy.sql import expression as sql @@ -58,40 +62,43 @@ BIND_TEMPLATES = { OPERATORS = { - operators.and_ : 'AND', - operators.or_ : 'OR', - operators.inv : 'NOT', - operators.add : '+', - operators.mul : '*', - operators.sub : '-', - operators.div : '/', - operators.mod : '%', - operators.truediv : '/', - operators.lt : '<', - operators.le : '<=', - operators.ne : '!=', - operators.gt : '>', - operators.ge : '>=', - operators.eq : '=', - operators.distinct_op : 'DISTINCT', - operators.concat_op : '||', - operators.like_op : lambda x, y, escape=None: '%s LIKE %s' % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''), - operators.notlike_op : lambda x, y, escape=None: '%s NOT LIKE %s' % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''), - operators.ilike_op : lambda x, y, escape=None: "lower(%s) LIKE lower(%s)" % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''), - operators.notilike_op : lambda x, y, escape=None: "lower(%s) NOT LIKE lower(%s)" % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''), - operators.between_op : 'BETWEEN', - operators.match_op : 'MATCH', - operators.in_op : 'IN', - operators.notin_op : 'NOT IN', + # binary + operators.and_ : ' AND ', + operators.or_ : ' OR ', + operators.add : ' + ', + operators.mul : ' * ', + operators.sub : ' - ', + # Py2K + operators.div : ' / ', + # end Py2K + operators.mod : ' % ', + operators.truediv : ' / ', + operators.lt : ' < ', + operators.le : ' <= ', + operators.ne : ' != ', + operators.gt : ' > ', + operators.ge : ' >= ', + operators.eq : ' = ', + operators.concat_op : ' || ', + operators.between_op : ' BETWEEN ', + operators.match_op : ' MATCH ', + operators.in_op : ' IN ', + operators.notin_op : ' NOT IN ', operators.comma_op : ', ', - operators.desc_op : 'DESC', - operators.asc_op : 'ASC', - operators.from_ : 'FROM', - operators.as_ : 'AS', - operators.exists : 'EXISTS', - operators.is_ : 'IS', - operators.isnot : 'IS NOT', - operators.collate : 'COLLATE', + operators.from_ : ' FROM ', + operators.as_ : ' AS ', + operators.is_ : ' IS ', + operators.isnot : ' IS NOT ', + operators.collate : ' COLLATE ', + + # unary + operators.exists : 'EXISTS ', + operators.distinct_op : 'DISTINCT ', + operators.inv : 'NOT ', + + # modifiers + operators.desc_op : ' DESC', + operators.asc_op : ' ASC', } FUNCTIONS = { @@ -140,7 +147,7 @@ class _CompileLabel(visitors.Visitable): def quote(self): return self.element.quote -class DefaultCompiler(engine.Compiled): +class SQLCompiler(engine.Compiled): """Default implementation of Compiled. Compiles ClauseElements into SQL strings. Uses a similar visit @@ -148,14 +155,14 @@ class DefaultCompiler(engine.Compiled): """ - operators = OPERATORS - functions = FUNCTIONS extract_map = EXTRACT_MAP - # if we are insert/update/delete. - # set to true when we visit an INSERT, UPDATE or DELETE + # class-level defaults which can be set at the instance + # level to define if this Compiled instance represents + # INSERT/UPDATE/DELETE isdelete = isinsert = isupdate = False - + returning = None + def __init__(self, dialect, statement, column_keys=None, inline=False, **kwargs): """Construct a new ``DefaultCompiler`` object. @@ -170,7 +177,9 @@ class DefaultCompiler(engine.Compiled): statement. """ - engine.Compiled.__init__(self, dialect, statement, column_keys, **kwargs) + engine.Compiled.__init__(self, dialect, statement, **kwargs) + + self.column_keys = column_keys # compile INSERT/UPDATE defaults/sequences inlined (no pre-execute) self.inline = inline or getattr(statement, 'inline', False) @@ -210,12 +219,6 @@ class DefaultCompiler(engine.Compiled): # or dialect.max_identifier_length self.truncated_names = {} - def compile(self): - self.string = self.process(self.statement) - - def process(self, obj, **kwargs): - return obj._compiler_dispatch(self, **kwargs) - def is_subquery(self): return len(self.stack) > 1 @@ -223,7 +226,6 @@ class DefaultCompiler(engine.Compiled): """return a dictionary of bind parameter keys and values""" if params: - params = util.column_dict(params) pd = {} for bindparam, name in self.bind_names.iteritems(): for paramname in (bindparam.key, bindparam.shortname, name): @@ -245,7 +247,10 @@ class DefaultCompiler(engine.Compiled): pd[self.bind_names[bindparam]] = bindparam.value return pd - params = property(construct_params) + params = property(construct_params, doc=""" + Return the bind params for this compiled object. + + """) def default_from(self): """Called when a SELECT statement has no froms, and no FROM clause is to be appended. @@ -267,10 +272,11 @@ class DefaultCompiler(engine.Compiled): self._truncated_identifier("colident", label.name) or label.name if result_map is not None: - result_map[labelname.lower()] = (label.name, (label, label.element, labelname), label.element.type) + result_map[labelname.lower()] = \ + (label.name, (label, label.element, labelname), label.element.type) - return self.process(label.element) + " " + \ - self.operator_string(operators.as_) + " " + \ + return self.process(label.element) + \ + OPERATORS[operators.as_] + \ self.preparer.format_label(label, labelname) else: return self.process(label.element) @@ -292,14 +298,17 @@ class DefaultCompiler(engine.Compiled): return name else: if column.table.schema: - schema_prefix = self.preparer.quote_schema(column.table.schema, column.table.quote_schema) + '.' + schema_prefix = self.preparer.quote_schema( + column.table.schema, + column.table.quote_schema) + '.' else: schema_prefix = '' tablename = column.table.name tablename = isinstance(tablename, sql._generated_label) and \ self._truncated_identifier("alias", tablename) or tablename - return schema_prefix + self.preparer.quote(tablename, column.table.quote) + "." + name + return schema_prefix + \ + self.preparer.quote(tablename, column.table.quote) + "." + name def escape_literal_column(self, text): """provide escaping for the literal_column() construct.""" @@ -314,7 +323,7 @@ class DefaultCompiler(engine.Compiled): return index.name def visit_typeclause(self, typeclause, **kwargs): - return typeclause.type.dialect_impl(self.dialect).get_col_spec() + return self.dialect.type_compiler.process(typeclause.type) def post_process_text(self, text): return text @@ -343,10 +352,8 @@ class DefaultCompiler(engine.Compiled): sep = clauselist.operator if sep is None: sep = " " - elif sep is operators.comma_op: - sep = ', ' else: - sep = " " + self.operator_string(clauselist.operator) + " " + sep = OPERATORS[clauselist.operator] return sep.join(s for s in (self.process(c) for c in clauselist.clauses) if s is not None) @@ -362,7 +369,8 @@ class DefaultCompiler(engine.Compiled): return x def visit_cast(self, cast, **kwargs): - return "CAST(%s AS %s)" % (self.process(cast.clause), self.process(cast.typeclause)) + return "CAST(%s AS %s)" % \ + (self.process(cast.clause), self.process(cast.typeclause)) def visit_extract(self, extract, **kwargs): field = self.extract_map.get(extract.field, extract.field) @@ -372,26 +380,26 @@ class DefaultCompiler(engine.Compiled): if result_map is not None: result_map[func.name.lower()] = (func.name, None, func.type) - name = self.function_string(func) - - if util.callable(name): - return name(*[self.process(x) for x in func.clauses]) + disp = getattr(self, "visit_%s_func" % func.name.lower(), None) + if disp: + return disp(func, **kwargs) else: - return ".".join(func.packagenames + [name]) % {'expr':self.function_argspec(func)} + name = FUNCTIONS.get(func.__class__, func.name + "%(expr)s") + return ".".join(func.packagenames + [name]) % \ + {'expr':self.function_argspec(func, **kwargs)} def function_argspec(self, func, **kwargs): return self.process(func.clause_expr, **kwargs) - def function_string(self, func): - return self.functions.get(func.__class__, self.functions.get(func.name, func.name + "%(expr)s")) - def visit_compound_select(self, cs, asfrom=False, parens=True, **kwargs): entry = self.stack and self.stack[-1] or {} self.stack.append({'from':entry.get('from', None), 'iswrapper':True}) - text = string.join((self.process(c, asfrom=asfrom, parens=False, compound_index=i) - for i, c in enumerate(cs.selects)), - " " + cs.keyword + " ") + text = (" " + cs.keyword + " ").join( + (self.process(c, asfrom=asfrom, parens=False, compound_index=i) + for i, c in enumerate(cs.selects)) + ) + group_by = self.process(cs._group_by_clause, asfrom=asfrom) if group_by: text += " GROUP BY " + group_by @@ -408,27 +416,57 @@ class DefaultCompiler(engine.Compiled): def visit_unary(self, unary, **kw): s = self.process(unary.element, **kw) if unary.operator: - s = self.operator_string(unary.operator) + " " + s + s = OPERATORS[unary.operator] + s if unary.modifier: - s = s + " " + self.operator_string(unary.modifier) + s = s + OPERATORS[unary.modifier] return s def visit_binary(self, binary, **kwargs): - op = self.operator_string(binary.operator) - if util.callable(op): - return op(self.process(binary.left), self.process(binary.right), **binary.modifiers) - else: - return self.process(binary.left) + " " + op + " " + self.process(binary.right) + + return self._operator_dispatch(binary.operator, + binary, + lambda opstr: self.process(binary.left) + opstr + self.process(binary.right), + **kwargs + ) - def operator_string(self, operator): - return self.operators.get(operator, str(operator)) + def visit_like_op(self, binary, **kw): + escape = binary.modifiers.get("escape", None) + return '%s LIKE %s' % (self.process(binary.left), self.process(binary.right)) \ + + (escape and ' ESCAPE \'%s\'' % escape or '') + def visit_notlike_op(self, binary, **kw): + escape = binary.modifiers.get("escape", None) + return '%s NOT LIKE %s' % (self.process(binary.left), self.process(binary.right)) \ + + (escape and ' ESCAPE \'%s\'' % escape or '') + + def visit_ilike_op(self, binary, **kw): + escape = binary.modifiers.get("escape", None) + return 'lower(%s) LIKE lower(%s)' % (self.process(binary.left), self.process(binary.right)) \ + + (escape and ' ESCAPE \'%s\'' % escape or '') + + def visit_notilike_op(self, binary, **kw): + escape = binary.modifiers.get("escape", None) + return 'lower(%s) NOT LIKE lower(%s)' % (self.process(binary.left), self.process(binary.right)) \ + + (escape and ' ESCAPE \'%s\'' % escape or '') + + def _operator_dispatch(self, operator, element, fn, **kw): + if util.callable(operator): + disp = getattr(self, "visit_%s" % operator.__name__, None) + if disp: + return disp(element, **kw) + else: + return fn(OPERATORS[operator]) + else: + return fn(" " + operator + " ") + def visit_bindparam(self, bindparam, **kwargs): name = self._truncate_bindparam(bindparam) if name in self.binds: existing = self.binds[name] if existing is not bindparam and (existing.unique or bindparam.unique): - raise exc.CompileError("Bind parameter '%s' conflicts with unique bind parameter of the same name" % bindparam.key) + raise exc.CompileError( + "Bind parameter '%s' conflicts with unique bind parameter of the same name" % bindparam.key + ) self.binds[bindparam.key] = self.binds[name] = bindparam return self.bindparam_string(name) @@ -491,7 +529,7 @@ class DefaultCompiler(engine.Compiled): if isinstance(column, sql._Label): return column - if select.use_labels and column._label: + if select and select.use_labels and column._label: return _CompileLabel(column, column._label) if \ @@ -501,13 +539,15 @@ class DefaultCompiler(engine.Compiled): column.table is not None and \ not isinstance(column.table, sql.Select): return _CompileLabel(column, sql._generated_label(column.name)) - elif not isinstance(column, (sql._UnaryExpression, sql._TextClause, sql._BindParamClause)) \ + elif not isinstance(column, + (sql._UnaryExpression, sql._TextClause, sql._BindParamClause)) \ and (not hasattr(column, 'name') or isinstance(column, sql.Function)): return _CompileLabel(column, column.anon_label) else: return column - def visit_select(self, select, asfrom=False, parens=True, iswrapper=False, compound_index=1, **kwargs): + def visit_select(self, select, asfrom=False, parens=True, + iswrapper=False, compound_index=1, **kwargs): entry = self.stack and self.stack[-1] or {} @@ -583,8 +623,10 @@ class DefaultCompiler(engine.Compiled): return text def get_select_precolumns(self, select): - """Called when building a ``SELECT`` statement, position is just before column list.""" - + """Called when building a ``SELECT`` statement, position is just before + column list. + + """ return select._distinct and "DISTINCT " or "" def order_by_clause(self, select): @@ -613,14 +655,16 @@ class DefaultCompiler(engine.Compiled): def visit_table(self, table, asfrom=False, **kwargs): if asfrom: if getattr(table, "schema", None): - return self.preparer.quote_schema(table.schema, table.quote_schema) + "." + self.preparer.quote(table.name, table.quote) + return self.preparer.quote_schema(table.schema, table.quote_schema) + \ + "." + self.preparer.quote(table.name, table.quote) else: return self.preparer.quote(table.name, table.quote) else: return "" def visit_join(self, join, asfrom=False, **kwargs): - return (self.process(join.left, asfrom=True) + (join.isouter and " LEFT OUTER JOIN " or " JOIN ") + \ + return (self.process(join.left, asfrom=True) + \ + (join.isouter and " LEFT OUTER JOIN " or " JOIN ") + \ self.process(join.right, asfrom=True) + " ON " + self.process(join.onclause)) def visit_sequence(self, seq): @@ -629,41 +673,75 @@ class DefaultCompiler(engine.Compiled): def visit_insert(self, insert_stmt): self.isinsert = True colparams = self._get_colparams(insert_stmt) - preparer = self.preparer - - insert = ' '.join(["INSERT"] + - [self.process(x) for x in insert_stmt._prefixes]) - if not colparams and not self.dialect.supports_default_values and not self.dialect.supports_empty_insert: + if not colparams and \ + not self.dialect.supports_default_values and \ + not self.dialect.supports_empty_insert: raise exc.CompileError( "The version of %s you are using does not support empty inserts." % self.dialect.name) - elif not colparams and self.dialect.supports_default_values: - return (insert + " INTO %s DEFAULT VALUES" % ( - (preparer.format_table(insert_stmt.table),))) - else: - return (insert + " INTO %s (%s) VALUES (%s)" % - (preparer.format_table(insert_stmt.table), - ', '.join([preparer.format_column(c[0]) - for c in colparams]), - ', '.join([c[1] for c in colparams]))) + preparer = self.preparer + supports_default_values = self.dialect.supports_default_values + + text = "INSERT" + + prefixes = [self.process(x) for x in insert_stmt._prefixes] + if prefixes: + text += " " + " ".join(prefixes) + + text += " INTO " + preparer.format_table(insert_stmt.table) + + if colparams or not supports_default_values: + text += " (%s)" % ', '.join([preparer.format_column(c[0]) + for c in colparams]) + + if self.returning or insert_stmt._returning: + self.returning = self.returning or insert_stmt._returning + returning_clause = self.returning_clause(insert_stmt, self.returning) + + # cheating + if returning_clause.startswith("OUTPUT"): + text += " " + returning_clause + returning_clause = None + + if not colparams and supports_default_values: + text += " DEFAULT VALUES" + else: + text += " VALUES (%s)" % \ + ', '.join([c[1] for c in colparams]) + + if self.returning and returning_clause: + text += " " + returning_clause + + return text + def visit_update(self, update_stmt): self.stack.append({'from': set([update_stmt.table])}) self.isupdate = True colparams = self._get_colparams(update_stmt) - text = ' '.join(( - "UPDATE", - self.preparer.format_table(update_stmt.table), - 'SET', - ', '.join(self.preparer.quote(c[0].name, c[0].quote) + '=' + c[1] - for c in colparams) - )) - + text = "UPDATE " + self.preparer.format_table(update_stmt.table) + + text += ' SET ' + \ + ', '.join( + self.preparer.quote(c[0].name, c[0].quote) + '=' + c[1] + for c in colparams + ) + + if update_stmt._returning: + self.returning = update_stmt._returning + returning_clause = self.returning_clause(update_stmt, update_stmt._returning) + if returning_clause.startswith("OUTPUT"): + text += " " + returning_clause + returning_clause = None + if update_stmt._whereclause: text += " WHERE " + self.process(update_stmt._whereclause) + if self.returning and returning_clause: + text += " " + returning_clause + self.stack.pop(-1) return text @@ -681,7 +759,8 @@ class DefaultCompiler(engine.Compiled): self.postfetch = [] self.prefetch = [] - + self.returning = [] + # no parameters in the statement, no parameters in the # compiled params - return binds for all columns if self.column_keys is None and stmt.parameters is None: @@ -701,6 +780,15 @@ class DefaultCompiler(engine.Compiled): # create a list of column assignment clauses as tuples values = [] + + need_pks = self.isinsert and \ + not self.inline and \ + not self.statement._returning + + implicit_returning = need_pks and \ + self.dialect.implicit_returning and \ + stmt.table.implicit_returning + for c in stmt.table.columns: if c.key in parameters: value = parameters[c.key] @@ -710,19 +798,48 @@ class DefaultCompiler(engine.Compiled): self.postfetch.append(c) value = self.process(value.self_group()) values.append((c, value)) + elif isinstance(c, schema.Column): if self.isinsert: - if (c.primary_key and self.dialect.preexecute_pk_sequences and not self.inline): - if (((isinstance(c.default, schema.Sequence) and - not c.default.optional) or - not self.dialect.supports_pk_autoincrement) or - (c.default is not None and - not isinstance(c.default, schema.Sequence))): - values.append((c, create_bind_param(c, None))) - self.prefetch.append(c) + if c.primary_key and \ + need_pks and \ + ( + c is not stmt.table._autoincrement_column or + not self.dialect.postfetch_lastrowid + ): + + if implicit_returning: + if isinstance(c.default, schema.Sequence): + proc = self.process(c.default) + if proc is not None: + values.append((c, proc)) + self.returning.append(c) + elif isinstance(c.default, schema.ColumnDefault) and \ + isinstance(c.default.arg, sql.ClauseElement): + values.append((c, self.process(c.default.arg.self_group()))) + self.returning.append(c) + elif c.default is not None: + values.append((c, create_bind_param(c, None))) + self.prefetch.append(c) + else: + self.returning.append(c) + else: + if ( + c.default is not None and \ + ( + self.dialect.supports_sequences or + not isinstance(c.default, schema.Sequence) + ) + ) or \ + self.dialect.preexecute_autoincrement_sequences: + + values.append((c, create_bind_param(c, None))) + self.prefetch.append(c) + elif isinstance(c.default, schema.ColumnDefault): if isinstance(c.default.arg, sql.ClauseElement): values.append((c, self.process(c.default.arg.self_group()))) + if not c.primary_key: # dont add primary key column to postfetch self.postfetch.append(c) @@ -759,9 +876,19 @@ class DefaultCompiler(engine.Compiled): text = "DELETE FROM " + self.preparer.format_table(delete_stmt.table) + if delete_stmt._returning: + self.returning = delete_stmt._returning + returning_clause = self.returning_clause(delete_stmt, delete_stmt._returning) + if returning_clause.startswith("OUTPUT"): + text += " " + returning_clause + returning_clause = None + if delete_stmt._whereclause: text += " WHERE " + self.process(delete_stmt._whereclause) + if self.returning and returning_clause: + text += " " + returning_clause + self.stack.pop(-1) return text @@ -775,110 +902,146 @@ class DefaultCompiler(engine.Compiled): def visit_release_savepoint(self, savepoint_stmt): return "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt) - def __str__(self): - return self.string or '' - -class DDLBase(engine.SchemaIterator): - def find_alterables(self, tables): - alterables = [] - class FindAlterables(schema.SchemaVisitor): - def visit_foreign_key_constraint(self, constraint): - if constraint.use_alter and constraint.table in tables: - alterables.append(constraint) - findalterables = FindAlterables() - for table in tables: - for c in table.constraints: - findalterables.traverse(c) - return alterables - - def _validate_identifier(self, ident, truncate): - if truncate: - if len(ident) > self.dialect.max_identifier_length: - counter = getattr(self, 'counter', 0) - self.counter = counter + 1 - return ident[0:self.dialect.max_identifier_length - 6] + "_" + hex(self.counter)[2:] - else: - return ident - else: - self.dialect.validate_identifier(ident) - return ident - - -class SchemaGenerator(DDLBase): - def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs): - super(SchemaGenerator, self).__init__(connection, **kwargs) - self.checkfirst = checkfirst - self.tables = tables and set(tables) or None - self.preparer = dialect.identifier_preparer - self.dialect = dialect - def get_column_specification(self, column, first_pk=False): - raise NotImplementedError() +class DDLCompiler(engine.Compiled): + @property + def preparer(self): + return self.dialect.identifier_preparer - def _can_create(self, table): - self.dialect.validate_identifier(table.name) - if table.schema: - self.dialect.validate_identifier(table.schema) - return not self.checkfirst or not self.dialect.has_table(self.connection, table.name, schema=table.schema) + def construct_params(self, params=None): + return None + + def visit_ddl(self, ddl, **kwargs): + # table events can substitute table and schema name + context = ddl.context + if isinstance(ddl.schema_item, schema.Table): + context = context.copy() + + preparer = self.dialect.identifier_preparer + path = preparer.format_table_seq(ddl.schema_item) + if len(path) == 1: + table, sch = path[0], '' + else: + table, sch = path[-1], path[0] - def visit_metadata(self, metadata): - if self.tables: - tables = self.tables - else: - tables = metadata.tables.values() - collection = [t for t in sql_util.sort_tables(tables) if self._can_create(t)] - for table in collection: - self.traverse_single(table) - if self.dialect.supports_alter: - for alterable in self.find_alterables(collection): - self.add_foreignkey(alterable) - - def visit_table(self, table): - for listener in table.ddl_listeners['before-create']: - listener('before-create', table, self.connection) + context.setdefault('table', table) + context.setdefault('schema', sch) + context.setdefault('fullname', preparer.format_table(ddl.schema_item)) + + return ddl.statement % context - for column in table.columns: - if column.default is not None: - self.traverse_single(column.default) + def visit_create_table(self, create): + table = create.element + preparer = self.dialect.identifier_preparer - self.append("\n" + " ".join(['CREATE'] + - table._prefixes + + text = "\n" + " ".join(['CREATE'] + \ + table._prefixes + \ ['TABLE', - self.preparer.format_table(table), - "("])) + preparer.format_table(table), + "("]) separator = "\n" # if only one primary key, specify it along with the column first_pk = False for column in table.columns: - self.append(separator) + text += separator separator = ", \n" - self.append("\t" + self.get_column_specification(column, first_pk=column.primary_key and not first_pk)) + text += "\t" + self.get_column_specification(column, first_pk=column.primary_key and not first_pk) if column.primary_key: first_pk = True - for constraint in column.constraints: - self.traverse_single(constraint) + const = " ".join(self.process(constraint) for constraint in column.constraints) + if const: + text += " " + const # On some DB order is significant: visit PK first, then the # other constraints (engine.ReflectionTest.testbasic failed on FB2) if table.primary_key: - self.traverse_single(table.primary_key) - for constraint in [c for c in table.constraints if c is not table.primary_key]: - self.traverse_single(constraint) + text += ", \n\t" + self.process(table.primary_key) + + const = ", \n\t".join( + self.process(constraint) for constraint in table.constraints + if constraint is not table.primary_key + and constraint.inline_ddl + and (not self.dialect.supports_alter or not getattr(constraint, 'use_alter', False)) + ) + if const: + text += ", \n\t" + const + + text += "\n)%s\n\n" % self.post_create_table(table) + return text + + def visit_drop_table(self, drop): + ret = "\nDROP TABLE " + self.preparer.format_table(drop.element) + if drop.cascade: + ret += " CASCADE CONSTRAINTS" + return ret + + def visit_create_index(self, create): + index = create.element + preparer = self.preparer + text = "CREATE " + if index.unique: + text += "UNIQUE " + text += "INDEX %s ON %s (%s)" \ + % (preparer.quote(self._validate_identifier(index.name, True), index.quote), + preparer.format_table(index.table), + ', '.join(preparer.quote(c.name, c.quote) + for c in index.columns)) + return text - self.append("\n)%s\n\n" % self.post_create_table(table)) - self.execute() + def visit_drop_index(self, drop): + index = drop.element + return "\nDROP INDEX " + \ + self.preparer.quote(self._validate_identifier(index.name, False), index.quote) - if hasattr(table, 'indexes'): - for index in table.indexes: - self.traverse_single(index) + def visit_add_constraint(self, create): + preparer = self.preparer + return "ALTER TABLE %s ADD %s" % ( + self.preparer.format_table(create.element.table), + self.process(create.element) + ) + + def visit_drop_constraint(self, drop): + preparer = self.preparer + return "ALTER TABLE %s DROP CONSTRAINT %s%s" % ( + self.preparer.format_table(drop.element.table), + self.preparer.format_constraint(drop.element), + " CASCADE" if drop.cascade else "" + ) + + def get_column_specification(self, column, **kwargs): + colspec = self.preparer.format_column(column) + " " + \ + self.dialect.type_compiler.process(column.type) + default = self.get_column_default_string(column) + if default is not None: + colspec += " DEFAULT " + default - for listener in table.ddl_listeners['after-create']: - listener('after-create', table, self.connection) + if not column.nullable: + colspec += " NOT NULL" + return colspec def post_create_table(self, table): return '' + def _compile(self, tocompile, parameters): + """compile the given string/parameters using this SchemaGenerator's dialect.""" + + compiler = self.dialect.statement_compiler(self.dialect, tocompile, parameters) + compiler.compile() + return compiler + + def _validate_identifier(self, ident, truncate): + if truncate: + if len(ident) > self.dialect.max_identifier_length: + counter = getattr(self, 'counter', 0) + self.counter = counter + 1 + return ident[0:self.dialect.max_identifier_length - 6] + "_" + hex(self.counter)[2:] + else: + return ident + else: + self.dialect.validate_identifier(ident) + return ident + def get_column_default_string(self, column): if isinstance(column.server_default, schema.DefaultClause): if isinstance(column.server_default.arg, basestring): @@ -888,149 +1051,190 @@ class SchemaGenerator(DDLBase): else: return None - def _compile(self, tocompile, parameters): - """compile the given string/parameters using this SchemaGenerator's dialect.""" - compiler = self.dialect.statement_compiler(self.dialect, tocompile, parameters) - compiler.compile() - return compiler - def visit_check_constraint(self, constraint): - self.append(", \n\t") + text = "" if constraint.name is not None: - self.append("CONSTRAINT %s " % - self.preparer.format_constraint(constraint)) - self.append(" CHECK (%s)" % constraint.sqltext) - self.define_constraint_deferrability(constraint) + text += "CONSTRAINT %s " % \ + self.preparer.format_constraint(constraint) + text += " CHECK (%s)" % constraint.sqltext + text += self.define_constraint_deferrability(constraint) + return text def visit_column_check_constraint(self, constraint): - self.append(" CHECK (%s)" % constraint.sqltext) - self.define_constraint_deferrability(constraint) + text = " CHECK (%s)" % constraint.sqltext + text += self.define_constraint_deferrability(constraint) + return text def visit_primary_key_constraint(self, constraint): if len(constraint) == 0: - return - self.append(", \n\t") + return '' + text = "" if constraint.name is not None: - self.append("CONSTRAINT %s " % self.preparer.format_constraint(constraint)) - self.append("PRIMARY KEY ") - self.append("(%s)" % ', '.join(self.preparer.quote(c.name, c.quote) - for c in constraint)) - self.define_constraint_deferrability(constraint) + text += "CONSTRAINT %s " % self.preparer.format_constraint(constraint) + text += "PRIMARY KEY " + text += "(%s)" % ', '.join(self.preparer.quote(c.name, c.quote) + for c in constraint) + text += self.define_constraint_deferrability(constraint) + return text def visit_foreign_key_constraint(self, constraint): - if constraint.use_alter and self.dialect.supports_alter: - return - self.append(", \n\t ") - self.define_foreign_key(constraint) - - def add_foreignkey(self, constraint): - self.append("ALTER TABLE %s ADD " % self.preparer.format_table(constraint.table)) - self.define_foreign_key(constraint) - self.execute() - - def define_foreign_key(self, constraint): - preparer = self.preparer + preparer = self.dialect.identifier_preparer + text = "" if constraint.name is not None: - self.append("CONSTRAINT %s " % - preparer.format_constraint(constraint)) - table = list(constraint.elements)[0].column.table - self.append("FOREIGN KEY(%s) REFERENCES %s (%s)" % ( + text += "CONSTRAINT %s " % \ + preparer.format_constraint(constraint) + remote_table = list(constraint._elements.values())[0].column.table + text += "FOREIGN KEY(%s) REFERENCES %s (%s)" % ( ', '.join(preparer.quote(f.parent.name, f.parent.quote) - for f in constraint.elements), - preparer.format_table(table), + for f in constraint._elements.values()), + preparer.format_table(remote_table), ', '.join(preparer.quote(f.column.name, f.column.quote) - for f in constraint.elements) - )) - if constraint.ondelete is not None: - self.append(" ON DELETE %s" % constraint.ondelete) - if constraint.onupdate is not None: - self.append(" ON UPDATE %s" % constraint.onupdate) - self.define_constraint_deferrability(constraint) + for f in constraint._elements.values()) + ) + text += self.define_constraint_cascades(constraint) + text += self.define_constraint_deferrability(constraint) + return text def visit_unique_constraint(self, constraint): - self.append(", \n\t") + text = "" if constraint.name is not None: - self.append("CONSTRAINT %s " % - self.preparer.format_constraint(constraint)) - self.append(" UNIQUE (%s)" % (', '.join(self.preparer.quote(c.name, c.quote) for c in constraint))) - self.define_constraint_deferrability(constraint) + text += "CONSTRAINT %s " % self.preparer.format_constraint(constraint) + text += " UNIQUE (%s)" % (', '.join(self.preparer.quote(c.name, c.quote) for c in constraint)) + text += self.define_constraint_deferrability(constraint) + return text + def define_constraint_cascades(self, constraint): + text = "" + if constraint.ondelete is not None: + text += " ON DELETE %s" % constraint.ondelete + if constraint.onupdate is not None: + text += " ON UPDATE %s" % constraint.onupdate + return text + def define_constraint_deferrability(self, constraint): + text = "" if constraint.deferrable is not None: if constraint.deferrable: - self.append(" DEFERRABLE") + text += " DEFERRABLE" else: - self.append(" NOT DEFERRABLE") + text += " NOT DEFERRABLE" if constraint.initially is not None: - self.append(" INITIALLY %s" % constraint.initially) + text += " INITIALLY %s" % constraint.initially + return text + + +class GenericTypeCompiler(engine.TypeCompiler): + def visit_CHAR(self, type_): + return "CHAR" + (type_.length and "(%d)" % type_.length or "") - def visit_column(self, column): - pass + def visit_NCHAR(self, type_): + return "NCHAR" + (type_.length and "(%d)" % type_.length or "") + + def visit_FLOAT(self, type_): + return "FLOAT" - def visit_index(self, index): - preparer = self.preparer - self.append("CREATE ") - if index.unique: - self.append("UNIQUE ") - self.append("INDEX %s ON %s (%s)" \ - % (preparer.quote(self._validate_identifier(index.name, True), index.quote), - preparer.format_table(index.table), - ', '.join(preparer.quote(c.name, c.quote) - for c in index.columns))) - self.execute() + def visit_NUMERIC(self, type_): + if type_.precision is None: + return "NUMERIC" + else: + return "NUMERIC(%(precision)s, %(scale)s)" % {'precision': type_.precision, 'scale' : type_.scale} + def visit_DECIMAL(self, type_): + return "DECIMAL" + + def visit_INTEGER(self, type_): + return "INTEGER" -class SchemaDropper(DDLBase): - def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs): - super(SchemaDropper, self).__init__(connection, **kwargs) - self.checkfirst = checkfirst - self.tables = tables - self.preparer = dialect.identifier_preparer - self.dialect = dialect + def visit_SMALLINT(self, type_): + return "SMALLINT" - def visit_metadata(self, metadata): - if self.tables: - tables = self.tables - else: - tables = metadata.tables.values() - collection = [t for t in reversed(sql_util.sort_tables(tables)) if self._can_drop(t)] - if self.dialect.supports_alter: - for alterable in self.find_alterables(collection): - self.drop_foreignkey(alterable) - for table in collection: - self.traverse_single(table) - - def _can_drop(self, table): - self.dialect.validate_identifier(table.name) - if table.schema: - self.dialect.validate_identifier(table.schema) - return not self.checkfirst or self.dialect.has_table(self.connection, table.name, schema=table.schema) - - def visit_index(self, index): - self.append("\nDROP INDEX " + self.preparer.quote(self._validate_identifier(index.name, False), index.quote)) - self.execute() - - def drop_foreignkey(self, constraint): - self.append("ALTER TABLE %s DROP CONSTRAINT %s" % ( - self.preparer.format_table(constraint.table), - self.preparer.format_constraint(constraint))) - self.execute() - - def visit_table(self, table): - for listener in table.ddl_listeners['before-drop']: - listener('before-drop', table, self.connection) + def visit_BIGINT(self, type_): + return "BIGINT" - for column in table.columns: - if column.default is not None: - self.traverse_single(column.default) + def visit_TIMESTAMP(self, type_): + return 'TIMESTAMP' + + def visit_DATETIME(self, type_): + return "DATETIME" + + def visit_DATE(self, type_): + return "DATE" + + def visit_TIME(self, type_): + return "TIME" + + def visit_CLOB(self, type_): + return "CLOB" - self.append("\nDROP TABLE " + self.preparer.format_table(table)) - self.execute() + def visit_NCLOB(self, type_): + return "NCLOB" - for listener in table.ddl_listeners['after-drop']: - listener('after-drop', table, self.connection) + def visit_VARCHAR(self, type_): + return "VARCHAR" + (type_.length and "(%d)" % type_.length or "") + def visit_NVARCHAR(self, type_): + return "NVARCHAR" + (type_.length and "(%d)" % type_.length or "") + def visit_BLOB(self, type_): + return "BLOB" + + def visit_BOOLEAN(self, type_): + return "BOOLEAN" + + def visit_TEXT(self, type_): + return "TEXT" + + def visit_binary(self, type_): + return self.visit_BLOB(type_) + + def visit_boolean(self, type_): + return self.visit_BOOLEAN(type_) + + def visit_time(self, type_): + return self.visit_TIME(type_) + + def visit_datetime(self, type_): + return self.visit_DATETIME(type_) + + def visit_date(self, type_): + return self.visit_DATE(type_) + + def visit_big_integer(self, type_): + return self.visit_BIGINT(type_) + + def visit_small_integer(self, type_): + return self.visit_SMALLINT(type_) + + def visit_integer(self, type_): + return self.visit_INTEGER(type_) + + def visit_float(self, type_): + return self.visit_FLOAT(type_) + + def visit_numeric(self, type_): + return self.visit_NUMERIC(type_) + + def visit_string(self, type_): + return self.visit_VARCHAR(type_) + + def visit_unicode(self, type_): + return self.visit_VARCHAR(type_) + + def visit_text(self, type_): + return self.visit_TEXT(type_) + + def visit_unicode_text(self, type_): + return self.visit_TEXT(type_) + + def visit_null(self, type_): + raise NotImplementedError("Can't generate DDL for the null type") + + def visit_type_decorator(self, type_): + return self.process(type_.type_engine(self.dialect)) + + def visit_user_defined(self, type_): + return type_.get_col_spec() + class IdentifierPreparer(object): """Handle quoting and case-folding of identifiers based on options.""" @@ -1176,24 +1380,24 @@ class IdentifierPreparer(object): else: return (self.format_table(table, use_schema=False), ) + @util.memoized_property + def _r_identifiers(self): + initial, final, escaped_final = \ + [re.escape(s) for s in + (self.initial_quote, self.final_quote, + self._escape_identifier(self.final_quote))] + r = re.compile( + r'(?:' + r'(?:%(initial)s((?:%(escaped)s|[^%(final)s])+)%(final)s' + r'|([^\.]+))(?=\.|$))+' % + { 'initial': initial, + 'final': final, + 'escaped': escaped_final }) + return r + def unformat_identifiers(self, identifiers): """Unpack 'schema.table.column'-like strings into components.""" - try: - r = self._r_identifiers - except AttributeError: - initial, final, escaped_final = \ - [re.escape(s) for s in - (self.initial_quote, self.final_quote, - self._escape_identifier(self.final_quote))] - r = re.compile( - r'(?:' - r'(?:%(initial)s((?:%(escaped)s|[^%(final)s])+)%(final)s' - r'|([^\.]+))(?=\.|$))+' % - { 'initial': initial, - 'final': final, - 'escaped': escaped_final }) - self._r_identifiers = r - + r = self._r_identifiers return [self._unescape_identifier(i) for i in [a or b for a, b in r.findall(identifiers)]] diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 83897ef05..91e0e74ae 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -29,10 +29,9 @@ to stay the same in future releases. import itertools, re from operator import attrgetter -from sqlalchemy import util, exc +from sqlalchemy import util, exc, types as sqltypes from sqlalchemy.sql import operators from sqlalchemy.sql.visitors import Visitable, cloned_traverse -from sqlalchemy import types as sqltypes import operator functions, schema, sql_util = None, None, None @@ -128,7 +127,8 @@ def select(columns=None, whereclause=None, from_obj=[], **kwargs): Similar functionality is also available via the ``select()`` method on any :class:`~sqlalchemy.sql.expression.FromClause`. - The returned object is an instance of :class:`~sqlalchemy.sql.expression.Select`. + The returned object is an instance of + :class:`~sqlalchemy.sql.expression.Select`. All arguments which accept ``ClauseElement`` arguments also accept string arguments, which will be converted as appropriate into @@ -241,7 +241,8 @@ def select(columns=None, whereclause=None, from_obj=[], **kwargs): """ if 'scalar' in kwargs: - util.warn_deprecated('scalar option is deprecated; see docs for details') + util.warn_deprecated( + 'scalar option is deprecated; see docs for details') scalar = kwargs.pop('scalar', False) s = Select(columns, whereclause=whereclause, from_obj=from_obj, **kwargs) if scalar: @@ -250,15 +251,16 @@ def select(columns=None, whereclause=None, from_obj=[], **kwargs): return s def subquery(alias, *args, **kwargs): - """Return an :class:`~sqlalchemy.sql.expression.Alias` object derived from a :class:`~sqlalchemy.sql.expression.Select`. + """Return an :class:`~sqlalchemy.sql.expression.Alias` object derived + from a :class:`~sqlalchemy.sql.expression.Select`. name alias name \*args, \**kwargs - all other arguments are delivered to the :func:`~sqlalchemy.sql.expression.select` - function. + all other arguments are delivered to the + :func:`~sqlalchemy.sql.expression.select` function. """ return Select(*args, **kwargs).alias(alias) @@ -280,12 +282,12 @@ def insert(table, values=None, inline=False, **kwargs): table columns. Note that the :meth:`~Insert.values()` generative method may also be used for this. - :param prefixes: A list of modifier keywords to be inserted between INSERT and INTO. - Alternatively, the :meth:`~Insert.prefix_with` generative method may be used. + :param prefixes: A list of modifier keywords to be inserted between INSERT + and INTO. Alternatively, the :meth:`~Insert.prefix_with` generative method + may be used. - :param inline: - if True, SQL defaults will be compiled 'inline' into the statement - and not pre-executed. + :param inline: if True, SQL defaults will be compiled 'inline' into the + statement and not pre-executed. If both `values` and compile-time bind parameters are present, the compile-time bind parameters override the information specified @@ -313,9 +315,9 @@ def update(table, whereclause=None, values=None, inline=False, **kwargs): :param table: The table to be updated. - :param whereclause: A ``ClauseElement`` describing the ``WHERE`` condition of the - ``UPDATE`` statement. Note that the :meth:`~Update.where()` generative - method may also be used for this. + :param whereclause: A ``ClauseElement`` describing the ``WHERE`` condition + of the ``UPDATE`` statement. Note that the :meth:`~Update.where()` + generative method may also be used for this. :param values: A dictionary which specifies the ``SET`` conditions of the @@ -347,7 +349,12 @@ def update(table, whereclause=None, values=None, inline=False, **kwargs): against the ``UPDATE`` statement. """ - return Update(table, whereclause=whereclause, values=values, inline=inline, **kwargs) + return Update( + table, + whereclause=whereclause, + values=values, + inline=inline, + **kwargs) def delete(table, whereclause = None, **kwargs): """Return a :class:`~sqlalchemy.sql.expression.Delete` clause element. @@ -357,9 +364,9 @@ def delete(table, whereclause = None, **kwargs): :param table: The table to be updated. - :param whereclause: A :class:`ClauseElement` describing the ``WHERE`` condition of the - ``UPDATE`` statement. Note that the :meth:`~Delete.where()` generative method - may be used instead. + :param whereclause: A :class:`ClauseElement` describing the ``WHERE`` + condition of the ``UPDATE`` statement. Note that the :meth:`~Delete.where()` + generative method may be used instead. """ return Delete(table, whereclause, **kwargs) @@ -368,8 +375,8 @@ def and_(*clauses): """Join a list of clauses together using the ``AND`` operator. The ``&`` operator is also overloaded on all - :class:`~sqlalchemy.sql.expression._CompareMixin` subclasses to produce the same - result. + :class:`~sqlalchemy.sql.expression._CompareMixin` subclasses to produce the + same result. """ if len(clauses) == 1: @@ -380,8 +387,8 @@ def or_(*clauses): """Join a list of clauses together using the ``OR`` operator. The ``|`` operator is also overloaded on all - :class:`~sqlalchemy.sql.expression._CompareMixin` subclasses to produce the same - result. + :class:`~sqlalchemy.sql.expression._CompareMixin` subclasses to produce the + same result. """ if len(clauses) == 1: @@ -392,8 +399,8 @@ def not_(clause): """Return a negation of the given clause, i.e. ``NOT(clause)``. The ``~`` operator is also overloaded on all - :class:`~sqlalchemy.sql.expression._CompareMixin` subclasses to produce the same - result. + :class:`~sqlalchemy.sql.expression._CompareMixin` subclasses to produce the + same result. """ return operators.inv(_literal_as_binds(clause)) @@ -408,8 +415,9 @@ def between(ctest, cleft, cright): Equivalent of SQL ``clausetest BETWEEN clauseleft AND clauseright``. - The ``between()`` method on all :class:`~sqlalchemy.sql.expression._CompareMixin` subclasses - provides similar functionality. + The ``between()`` method on all + :class:`~sqlalchemy.sql.expression._CompareMixin` subclasses provides + similar functionality. """ ctest = _literal_as_binds(ctest) @@ -517,7 +525,8 @@ def exists(*args, **kwargs): def union(*selects, **kwargs): """Return a ``UNION`` of multiple selectables. - The returned object is an instance of :class:`~sqlalchemy.sql.expression.CompoundSelect`. + The returned object is an instance of + :class:`~sqlalchemy.sql.expression.CompoundSelect`. A similar ``union()`` method is available on all :class:`~sqlalchemy.sql.expression.FromClause` subclasses. @@ -535,7 +544,8 @@ def union(*selects, **kwargs): def union_all(*selects, **kwargs): """Return a ``UNION ALL`` of multiple selectables. - The returned object is an instance of :class:`~sqlalchemy.sql.expression.CompoundSelect`. + The returned object is an instance of + :class:`~sqlalchemy.sql.expression.CompoundSelect`. A similar ``union_all()`` method is available on all :class:`~sqlalchemy.sql.expression.FromClause` subclasses. @@ -553,7 +563,8 @@ def union_all(*selects, **kwargs): def except_(*selects, **kwargs): """Return an ``EXCEPT`` of multiple selectables. - The returned object is an instance of :class:`~sqlalchemy.sql.expression.CompoundSelect`. + The returned object is an instance of + :class:`~sqlalchemy.sql.expression.CompoundSelect`. \*selects a list of :class:`~sqlalchemy.sql.expression.Select` instances. @@ -568,7 +579,8 @@ def except_(*selects, **kwargs): def except_all(*selects, **kwargs): """Return an ``EXCEPT ALL`` of multiple selectables. - The returned object is an instance of :class:`~sqlalchemy.sql.expression.CompoundSelect`. + The returned object is an instance of + :class:`~sqlalchemy.sql.expression.CompoundSelect`. \*selects a list of :class:`~sqlalchemy.sql.expression.Select` instances. @@ -583,7 +595,8 @@ def except_all(*selects, **kwargs): def intersect(*selects, **kwargs): """Return an ``INTERSECT`` of multiple selectables. - The returned object is an instance of :class:`~sqlalchemy.sql.expression.CompoundSelect`. + The returned object is an instance of + :class:`~sqlalchemy.sql.expression.CompoundSelect`. \*selects a list of :class:`~sqlalchemy.sql.expression.Select` instances. @@ -598,7 +611,8 @@ def intersect(*selects, **kwargs): def intersect_all(*selects, **kwargs): """Return an ``INTERSECT ALL`` of multiple selectables. - The returned object is an instance of :class:`~sqlalchemy.sql.expression.CompoundSelect`. + The returned object is an instance of + :class:`~sqlalchemy.sql.expression.CompoundSelect`. \*selects a list of :class:`~sqlalchemy.sql.expression.Select` instances. @@ -613,8 +627,8 @@ def intersect_all(*selects, **kwargs): def alias(selectable, alias=None): """Return an :class:`~sqlalchemy.sql.expression.Alias` object. - An ``Alias`` represents any :class:`~sqlalchemy.sql.expression.FromClause` with - an alternate name assigned within SQL, typically using the ``AS`` + An ``Alias`` represents any :class:`~sqlalchemy.sql.expression.FromClause` + with an alternate name assigned within SQL, typically using the ``AS`` clause when generated, e.g. ``SELECT * FROM table AS aliasname``. Similar functionality is available via the ``alias()`` method @@ -656,7 +670,8 @@ def literal(value, type_=None): return _BindParamClause(None, value, type_=type_, unique=True) def label(name, obj): - """Return a :class:`~sqlalchemy.sql.expression._Label` object for the given :class:`~sqlalchemy.sql.expression.ColumnElement`. + """Return a :class:`~sqlalchemy.sql.expression._Label` object for the given + :class:`~sqlalchemy.sql.expression.ColumnElement`. A label changes the name of an element in the columns clause of a ``SELECT`` statement, typically via the ``AS`` SQL keyword. @@ -674,11 +689,13 @@ def label(name, obj): return _Label(name, obj) def column(text, type_=None): - """Return a textual column clause, as would be in the columns clause of a ``SELECT`` statement. + """Return a textual column clause, as would be in the columns clause of a + ``SELECT`` statement. - The object returned is an instance of :class:`~sqlalchemy.sql.expression.ColumnClause`, - which represents the "syntactical" portion of the schema-level - :class:`~sqlalchemy.schema.Column` object. + The object returned is an instance of + :class:`~sqlalchemy.sql.expression.ColumnClause`, which represents the + "syntactical" portion of the schema-level :class:`~sqlalchemy.schema.Column` + object. text the name of the column. Quoting rules will be applied to the @@ -710,9 +727,9 @@ def literal_column(text, type_=None): :func:`~sqlalchemy.sql.expression.column` function. type\_ - an optional :class:`~sqlalchemy.types.TypeEngine` object which will provide - result-set translation and additional expression semantics for this - column. If left as None the type will be NullType. + an optional :class:`~sqlalchemy.types.TypeEngine` object which will + provide result-set translation and additional expression semantics for + this column. If left as None the type will be NullType. """ return ColumnClause(text, type_=type_, is_literal=True) @@ -752,7 +769,8 @@ def bindparam(key, value=None, shortname=None, type_=None, unique=False): return _BindParamClause(key, value, type_=type_, unique=unique, shortname=shortname) def outparam(key, type_=None): - """Create an 'OUT' parameter for usage in functions (stored procedures), for databases which support them. + """Create an 'OUT' parameter for usage in functions (stored procedures), for + databases which support them. The ``outparam`` can be used like a regular function parameter. The "output" value will be available from the @@ -760,7 +778,8 @@ def outparam(key, type_=None): attribute, which returns a dictionary containing the values. """ - return _BindParamClause(key, None, type_=type_, unique=False, isoutparam=True) + return _BindParamClause( + key, None, type_=type_, unique=False, isoutparam=True) def text(text, bind=None, *args, **kwargs): """Create literal text to be inserted into a query. @@ -803,8 +822,10 @@ def text(text, bind=None, *args, **kwargs): return _TextClause(text, bind=bind, *args, **kwargs) def null(): - """Return a :class:`_Null` object, which compiles to ``NULL`` in a sql statement.""" - + """Return a :class:`_Null` object, which compiles to ``NULL`` in a sql + statement. + + """ return _Null() class _FunctionGenerator(object): @@ -839,7 +860,8 @@ class _FunctionGenerator(object): if func is not None: return func(*c, **o) - return Function(self.__names[-1], packagenames=self.__names[0:-1], *c, **o) + return Function( + self.__names[-1], packagenames=self.__names[0:-1], *c, **o) # "func" global - i.e. func.count() func = _FunctionGenerator() @@ -861,10 +883,19 @@ def _clone(element): return element._clone() def _expand_cloned(elements): - """expand the given set of ClauseElements to be the set of all 'cloned' predecessors.""" - + """expand the given set of ClauseElements to be the set of all 'cloned' + predecessors. + + """ return itertools.chain(*[x._cloned_set for x in elements]) +def _select_iterables(elements): + """expand tables into individual columns in the + given list of column expressions. + + """ + return itertools.chain(*[c._select_iterable for c in elements]) + def _cloned_intersection(a, b): """return the intersection of sets a and b, counting any overlap between 'cloned' predecessors. @@ -879,7 +910,8 @@ def _compound_select(keyword, *selects, **kwargs): return CompoundSelect(keyword, *selects, **kwargs) def _is_literal(element): - return not isinstance(element, Visitable) and not hasattr(element, '__clause_element__') + return not isinstance(element, Visitable) and \ + not hasattr(element, '__clause_element__') def _from_objects(*elements): return itertools.chain(*[element._from_objects for element in elements]) @@ -940,27 +972,36 @@ def _no_literals(element): return element def _corresponding_column_or_error(fromclause, column, require_embedded=False): - c = fromclause.corresponding_column(column, require_embedded=require_embedded) + c = fromclause.corresponding_column(column, + require_embedded=require_embedded) if not c: - raise exc.InvalidRequestError("Given column '%s', attached to table '%s', " + raise exc.InvalidRequestError( + "Given column '%s', attached to table '%s', " "failed to locate a corresponding column from table '%s'" - % (column, getattr(column, 'table', None), fromclause.description)) + % + (column, + getattr(column, 'table', None),fromclause.description) + ) return c def is_column(col): """True if ``col`` is an instance of ``ColumnElement``.""" + return isinstance(col, ColumnElement) class ClauseElement(Visitable): - """Base class for elements of a programmatically constructed SQL expression.""" - + """Base class for elements of a programmatically constructed SQL + expression. + + """ __visit_name__ = 'clause' _annotations = {} supports_execution = False _from_objects = [] - + _bind = None + def _clone(self): """Create a shallow copy of this ClauseElement. @@ -984,7 +1025,8 @@ class ClauseElement(Visitable): @util.memoized_property def _cloned_set(self): - """Return the set consisting all cloned anscestors of this ClauseElement. + """Return the set consisting all cloned anscestors of this + ClauseElement. Includes this ClauseElement. This accessor tends to be used for FromClause objects to identify 'equivalent' FROM clauses, regardless @@ -1004,15 +1046,20 @@ class ClauseElement(Visitable): return d def _annotate(self, values): - """return a copy of this ClauseElement with the given annotations dictionary.""" - + """return a copy of this ClauseElement with the given annotations + dictionary. + + """ global Annotated if Annotated is None: from sqlalchemy.sql.util import Annotated return Annotated(self, values) def _deannotate(self): - """return a copy of this ClauseElement with an empty annotations dictionary.""" + """return a copy of this ClauseElement with an empty annotations + dictionary. + + """ return self._clone() def unique_params(self, *optionaldict, **kwargs): @@ -1044,7 +1091,8 @@ class ClauseElement(Visitable): if len(optionaldict) == 1: kwargs.update(optionaldict[0]) elif len(optionaldict) > 1: - raise exc.ArgumentError("params() takes zero or one positional dictionary argument") + raise exc.ArgumentError( + "params() takes zero or one positional dictionary argument") def visit_bindparam(bind): if bind.key in kwargs: @@ -1088,15 +1136,20 @@ class ClauseElement(Visitable): def self_group(self, against=None): return self + # TODO: remove .bind as a method from the root ClauseElement. + # we should only be deriving binds from FromClause elements + # and certain SchemaItem subclasses. + # the "search_for_bind" functionality can still be used by + # execute(), however. @property def bind(self): - """Returns the Engine or Connection to which this ClauseElement is bound, or None if none found.""" + """Returns the Engine or Connection to which this ClauseElement is + bound, or None if none found. + + """ + if self._bind is not None: + return self._bind - try: - if self._bind is not None: - return self._bind - except AttributeError: - pass for f in _from_objects(self): if f is self: continue @@ -1121,68 +1174,82 @@ class ClauseElement(Visitable): return e._execute_clauseelement(self, multiparams, params) def scalar(self, *multiparams, **params): - """Compile and execute this ``ClauseElement``, returning the result's scalar representation.""" - + """Compile and execute this ``ClauseElement``, returning the result's + scalar representation. + + """ return self.execute(*multiparams, **params).scalar() - def compile(self, bind=None, column_keys=None, compiler=None, dialect=None, inline=False): + def compile(self, bind=None, dialect=None, **kw): """Compile this SQL expression. The return value is a :class:`~sqlalchemy.engine.Compiled` object. - Calling `str()` or `unicode()` on the returned value will yield - a string representation of the result. The :class:`~sqlalchemy.engine.Compiled` - object also can return a dictionary of bind parameter names and - values using the `params` accessor. + Calling `str()` or `unicode()` on the returned value will yield a string + representation of the result. The :class:`~sqlalchemy.engine.Compiled` + object also can return a dictionary of bind parameter names and values + using the `params` accessor. :param bind: An ``Engine`` or ``Connection`` from which a ``Compiled`` will be acquired. This argument takes precedence over this ``ClauseElement``'s bound engine, if any. - :param column_keys: Used for INSERT and UPDATE statements, a list of - column names which should be present in the VALUES clause - of the compiled statement. If ``None``, all columns - from the target table object are rendered. - - :param compiler: A ``Compiled`` instance which will be used to compile - this expression. This argument takes precedence - over the `bind` and `dialect` arguments as well as - this ``ClauseElement``'s bound engine, if - any. - :param dialect: A ``Dialect`` instance frmo which a ``Compiled`` will be acquired. This argument takes precedence over the `bind` argument as well as this ``ClauseElement``'s bound engine, if any. - :param inline: Used for INSERT statements, for a dialect which does - not support inline retrieval of newly generated - primary key columns, will force the expression used - to create the new primary key value to be rendered - inline within the INSERT statement's VALUES clause. - This typically refers to Sequence execution but - may also refer to any server-side default generation - function associated with a primary key `Column`. + \**kw + + Keyword arguments are passed along to the compiler, + which can affect the string produced. + + Keywords for a statement compiler are: + + column_keys + Used for INSERT and UPDATE statements, a list of + column names which should be present in the VALUES clause + of the compiled statement. If ``None``, all columns + from the target table object are rendered. + + inline + Used for INSERT statements, for a dialect which does + not support inline retrieval of newly generated + primary key columns, will force the expression used + to create the new primary key value to be rendered + inline within the INSERT statement's VALUES clause. + This typically refers to Sequence execution but + may also refer to any server-side default generation + function associated with a primary key `Column`. """ - if compiler is None: - if dialect is not None: - compiler = dialect.statement_compiler(dialect, self, column_keys=column_keys, inline=inline) - elif bind is not None: - compiler = bind.statement_compiler(self, column_keys=column_keys, inline=inline) - elif self.bind is not None: - compiler = self.bind.statement_compiler(self, column_keys=column_keys, inline=inline) + + if not dialect: + if bind: + dialect = bind.dialect + elif self.bind: + dialect = self.bind.dialect + bind = self.bind else: global DefaultDialect if DefaultDialect is None: from sqlalchemy.engine.default import DefaultDialect dialect = DefaultDialect() - compiler = dialect.statement_compiler(dialect, self, column_keys=column_keys, inline=inline) + compiler = self._compiler(dialect, bind=bind, **kw) compiler.compile() return compiler - + + def _compiler(self, dialect, **kw): + """Return a compiler appropriate for this ClauseElement, given a Dialect.""" + + return dialect.statement_compiler(dialect, self, **kw) + def __str__(self): + # Py3K + #return unicode(self.compile()) + # Py2K return unicode(self.compile()).encode('ascii', 'backslashreplace') + # end Py2K def __and__(self, other): return and_(self, other) @@ -1193,11 +1260,25 @@ class ClauseElement(Visitable): def __invert__(self): return self._negate() + if util.jython: + def __hash__(self): + """Return a distinct hash code. + + ClauseElements may have special equality comparisons which + makes us rely on them having unique hash codes for use in + hash-based collections. Stock __hash__ doesn't guarantee + unique values on platforms with moving GCs. + """ + return id(self) + def _negate(self): if hasattr(self, 'negation_clause'): return self.negation_clause else: - return _UnaryExpression(self.self_group(against=operators.inv), operator=operators.inv, negate=None) + return _UnaryExpression( + self.self_group(against=operators.inv), + operator=operators.inv, + negate=None) def __repr__(self): friendly = getattr(self, 'description', None) @@ -1211,6 +1292,12 @@ class ClauseElement(Visitable): class _Immutable(object): """mark a ClauseElement as 'immutable' when expressions are cloned.""" + def unique_params(self, *optionaldict, **kwargs): + raise NotImplementedError("Immutable objects do not support copying") + + def params(self, *optionaldict, **kwargs): + raise NotImplementedError("Immutable objects do not support copying") + def _clone(self): return self @@ -1330,6 +1417,9 @@ class ColumnOperators(Operators): def __truediv__(self, other): return self.operate(operators.truediv, other) + def __rtruediv__(self, other): + return self.reverse_operate(operators.truediv, other) + class _CompareMixin(ColumnOperators): """Defines comparison and math operations for ``ClauseElement`` instances.""" @@ -1365,7 +1455,9 @@ class _CompareMixin(ColumnOperators): operators.add : (__operate,), operators.mul : (__operate,), operators.sub : (__operate,), + # Py2K operators.div : (__operate,), + # end Py2K operators.mod : (__operate,), operators.truediv : (__operate,), operators.lt : (__compare, operators.ge), @@ -1632,7 +1724,7 @@ class ColumnCollection(util.OrderedProperties): def __init__(self, *cols): super(ColumnCollection, self).__init__() - [self.add(c) for c in cols] + self.update((c.key, c) for c in cols) def __str__(self): return repr([str(c) for c in self]) @@ -1734,8 +1826,10 @@ class Selectable(ClauseElement): __visit_name__ = 'selectable' class FromClause(Selectable): - """Represent an element that can be used within the ``FROM`` clause of a ``SELECT`` statement.""" - + """Represent an element that can be used within the ``FROM`` + clause of a ``SELECT`` statement. + + """ __visit_name__ = 'fromclause' named_with_column = False _hide_froms = [] @@ -1749,7 +1843,11 @@ class FromClause(Selectable): col = list(self.primary_key)[0] else: col = list(self.columns)[0] - return select([func.count(col).label('tbl_row_count')], whereclause, from_obj=[self], **params) + return select( + [func.count(col).label('tbl_row_count')], + whereclause, + from_obj=[self], + **params) def select(self, whereclause=None, **params): """return a SELECT of this ``FromClause``.""" @@ -1794,8 +1892,10 @@ class FromClause(Selectable): return fromclause in self._cloned_set def replace_selectable(self, old, alias): - """replace all occurences of FromClause 'old' with the given Alias object, returning a copy of this ``FromClause``.""" - + """replace all occurences of FromClause 'old' with the given Alias + object, returning a copy of this ``FromClause``. + + """ global ClauseAdapter if ClauseAdapter is None: from sqlalchemy.sql.util import ClauseAdapter @@ -1846,24 +1946,30 @@ class FromClause(Selectable): col, intersect = c, i elif len(i) > len(intersect): # 'c' has a larger field of correspondence than 'col'. - # i.e. selectable.c.a1_x->a1.c.x->table.c.x matches a1.c.x->table.c.x better than + # i.e. selectable.c.a1_x->a1.c.x->table.c.x matches + # a1.c.x->table.c.x better than # selectable.c.x->table.c.x does. col, intersect = c, i elif i == intersect: # they have the same field of correspondence. - # see which proxy_set has fewer columns in it, which indicates a - # closer relationship with the root column. Also take into account the - # "weight" attribute which CompoundSelect() uses to give higher precedence to - # columns based on vertical position in the compound statement, and discard columns - # that have no reference to the target column (also occurs with CompoundSelect) + # see which proxy_set has fewer columns in it, which indicates + # a closer relationship with the root column. Also take into + # account the "weight" attribute which CompoundSelect() uses to + # give higher precedence to columns based on vertical position + # in the compound statement, and discard columns that have no + # reference to the target column (also occurs with + # CompoundSelect) col_distance = util.reduce(operator.add, - [sc._annotations.get('weight', 1) for sc in col.proxy_set if sc.shares_lineage(column)] + [sc._annotations.get('weight', 1) + for sc in col.proxy_set + if sc.shares_lineage(column)] ) c_distance = util.reduce(operator.add, - [sc._annotations.get('weight', 1) for sc in c.proxy_set if sc.shares_lineage(column)] + [sc._annotations.get('weight', 1) + for sc in c.proxy_set + if sc.shares_lineage(column)] ) - if \ - c_distance < col_distance: + if c_distance < col_distance: col, intersect = c, i return col @@ -2011,7 +2117,9 @@ class _BindParamClause(ColumnElement): the same type. """ - return isinstance(other, _BindParamClause) and other.type.__class__ == self.type.__class__ and self.value == other.value + return isinstance(other, _BindParamClause) and \ + other.type.__class__ == self.type.__class__ and \ + self.value == other.value def __getstate__(self): """execute a deferred value for serialization purposes.""" @@ -2024,7 +2132,9 @@ class _BindParamClause(ColumnElement): return d def __repr__(self): - return "_BindParamClause(%s, %s, type_=%s)" % (repr(self.key), repr(self.value), repr(self.type)) + return "_BindParamClause(%r, %r, type_=%r)" % ( + self.key, self.value, self.type + ) class _TypeClause(ClauseElement): """Handle a type keyword in a SQL statement. @@ -2057,7 +2167,8 @@ class _TextClause(ClauseElement): _hide_froms = [] - def __init__(self, text = "", bind=None, bindparams=None, typemap=None, autocommit=False): + def __init__(self, text = "", bind=None, + bindparams=None, typemap=None, autocommit=False): self._bind = bind self.bindparams = {} self.typemap = typemap @@ -2157,7 +2268,8 @@ class ClauseList(ClauseElement): return list(itertools.chain(*[c._from_objects for c in self.clauses])) def self_group(self, against=None): - if self.group and self.operator is not against and operators.is_precedent(self.operator, against): + if self.group and self.operator is not against and \ + operators.is_precedent(self.operator, against): return _Grouping(self) else: return self @@ -2200,9 +2312,13 @@ class _Case(ColumnElement): pass if value: - whenlist = [(_literal_as_binds(c).self_group(), _literal_as_binds(r)) for (c, r) in whens] + whenlist = [ + (_literal_as_binds(c).self_group(), _literal_as_binds(r)) for (c, r) in whens + ] else: - whenlist = [(_no_literals(c).self_group(), _literal_as_binds(r)) for (c, r) in whens] + whenlist = [ + (_no_literals(c).self_group(), _literal_as_binds(r)) for (c, r) in whens + ] if whenlist: type_ = list(whenlist[-1])[-1].type @@ -2472,16 +2588,19 @@ class _Exists(_UnaryExpression): return e def select_from(self, clause): - """return a new exists() construct with the given expression set as its FROM clause.""" - + """return a new exists() construct with the given expression set as its FROM + clause. + + """ e = self._clone() e.element = self.element.select_from(clause).self_group() return e def where(self, clause): - """return a new exists() construct with the given expression added to its WHERE clause, joined - to the existing clause via AND, if any.""" - + """return a new exists() construct with the given expression added to its WHERE + clause, joined to the existing clause via AND, if any. + + """ e = self._clone() e.element = self.element.where(clause).self_group() return e @@ -2517,7 +2636,9 @@ class Join(FromClause): id(self.right)) def is_derived_from(self, fromclause): - return fromclause is self or self.left.is_derived_from(fromclause) or self.right.is_derived_from(fromclause) + return fromclause is self or \ + self.left.is_derived_from(fromclause) or\ + self.right.is_derived_from(fromclause) def self_group(self, against=None): return _FromGrouping(self) @@ -2634,7 +2755,11 @@ class Alias(FromClause): @property def description(self): + # Py3K + #return self.name + # Py2K return self.name.encode('ascii', 'backslashreplace') + # end Py2K def as_scalar(self): try: @@ -2762,14 +2887,19 @@ class _Label(ColumnElement): def __init__(self, name, element, type_=None): while isinstance(element, _Label): element = element.element - self.name = self.key = self._label = name or _generated_label("%%(%d %s)s" % (id(self), getattr(element, 'name', 'anon'))) + self.name = self.key = self._label = name or \ + _generated_label("%%(%d %s)s" % ( + id(self), getattr(element, 'name', 'anon')) + ) self._element = element self._type = type_ self.quote = element.quote @util.memoized_property def type(self): - return sqltypes.to_instance(self._type or getattr(self._element, 'type', None)) + return sqltypes.to_instance( + self._type or getattr(self._element, 'type', None) + ) @util.memoized_property def element(self): @@ -2842,7 +2972,11 @@ class ColumnClause(_Immutable, ColumnElement): @util.memoized_property def description(self): + # Py3K + #return self.name + # Py2K return self.name.encode('ascii', 'backslashreplace') + # end Py2K @util.memoized_property def _label(self): @@ -2891,7 +3025,12 @@ class ColumnClause(_Immutable, ColumnElement): # propagate the "is_literal" flag only if we are keeping our name, # otherwise its considered to be a label is_literal = self.is_literal and (name is None or name == self.name) - c = ColumnClause(name or self.name, selectable=selectable, type_=self.type, is_literal=is_literal) + c = ColumnClause( + name or self.name, + selectable=selectable, + type_=self.type, + is_literal=is_literal + ) c.proxies = [self] if attach: selectable.columns[c.name] = c @@ -2927,7 +3066,11 @@ class TableClause(_Immutable, FromClause): @util.memoized_property def description(self): + # Py3K + #return self.name + # Py2K return self.name.encode('ascii', 'backslashreplace') + # end Py2K def append_column(self, c): self._columns[c.name] = c @@ -2944,7 +3087,11 @@ class TableClause(_Immutable, FromClause): col = list(self.primary_key)[0] else: col = list(self.columns)[0] - return select([func.count(col).label('tbl_row_count')], whereclause, from_obj=[self], **params) + return select( + [func.count(col).label('tbl_row_count')], + whereclause, + from_obj=[self], + **params) def insert(self, values=None, inline=False, **kwargs): """Generate an :func:`~sqlalchemy.sql.expression.insert()` construct.""" @@ -2954,7 +3101,8 @@ class TableClause(_Immutable, FromClause): def update(self, whereclause=None, values=None, inline=False, **kwargs): """Generate an :func:`~sqlalchemy.sql.expression.update()` construct.""" - return update(self, whereclause=whereclause, values=values, inline=inline, **kwargs) + return update(self, whereclause=whereclause, + values=values, inline=inline, **kwargs) def delete(self, whereclause=None, **kwargs): """Generate a :func:`~sqlalchemy.sql.expression.delete()` construct.""" @@ -3004,7 +3152,8 @@ class _SelectBaseMixin(object): Typically, a select statement which has only one column in its columns clause is eligible to be used as a scalar expression. - The returned object is an instance of :class:`~sqlalchemy.sql.expression._ScalarSelect`. + The returned object is an instance of + :class:`~sqlalchemy.sql.expression._ScalarSelect`. """ return _ScalarSelect(self) @@ -3013,10 +3162,10 @@ class _SelectBaseMixin(object): def apply_labels(self): """return a new selectable with the 'use_labels' flag set to True. - This will result in column expressions being generated using labels against their table - name, such as "SELECT somecolumn AS tablename_somecolumn". This allows selectables which - contain multiple FROM clauses to produce a unique set of column names regardless of name conflicts - among the individual FROM clauses. + This will result in column expressions being generated using labels against their + table name, such as "SELECT somecolumn AS tablename_somecolumn". This allows + selectables which contain multiple FROM clauses to produce a unique set of column + names regardless of name conflicts among the individual FROM clauses. """ self.use_labels = True @@ -3127,7 +3276,8 @@ class _ScalarSelect(_Grouping): return list(self.inner_columns)[0]._make_proxy(selectable, name) class CompoundSelect(_SelectBaseMixin, FromClause): - """Forms the basis of ``UNION``, ``UNION ALL``, and other SELECT-based set operations.""" + """Forms the basis of ``UNION``, ``UNION ALL``, and other + SELECT-based set operations.""" __visit_name__ = 'compound_select' @@ -3147,7 +3297,8 @@ class CompoundSelect(_SelectBaseMixin, FromClause): elif len(s.c) != numcols: raise exc.ArgumentError( "All selectables passed to CompoundSelect must " - "have identical numbers of columns; select #%d has %d columns, select #%d has %d" % + "have identical numbers of columns; select #%d has %d columns," + " select #%d has %d" % (1, len(self.selects[0].c), n+1, len(s.c)) ) @@ -3222,7 +3373,15 @@ class Select(_SelectBaseMixin, FromClause): __visit_name__ = 'select' - def __init__(self, columns, whereclause=None, from_obj=None, distinct=False, having=None, correlate=True, prefixes=None, **kwargs): + def __init__(self, + columns, + whereclause=None, + from_obj=None, + distinct=False, + having=None, + correlate=True, + prefixes=None, + **kwargs): """Construct a Select object. The public constructor for Select is the @@ -3241,9 +3400,9 @@ class Select(_SelectBaseMixin, FromClause): if columns: self._raw_columns = [ - isinstance(c, _ScalarSelect) and c.self_group(against=operators.comma_op) or c - for c in - [_literal_as_column(c) for c in columns] + isinstance(c, _ScalarSelect) and + c.self_group(against=operators.comma_op) or c + for c in [_literal_as_column(c) for c in columns] ] self._froms.update(_from_objects(*self._raw_columns)) @@ -3331,8 +3490,7 @@ class Select(_SelectBaseMixin, FromClause): be rendered into the columns clause of the resulting SELECT statement. """ - - return itertools.chain(*[c._select_iterable for c in self._raw_columns]) + return _select_iterables(self._raw_columns) def is_derived_from(self, fromclause): if self in fromclause._cloned_set: @@ -3347,7 +3505,7 @@ class Select(_SelectBaseMixin, FromClause): self._reset_exported() from_cloned = dict((f, clone(f)) for f in self._froms.union(self._correlate)) - self._froms = set(from_cloned[f] for f in self._froms) + self._froms = util.OrderedSet(from_cloned[f] for f in self._froms) self._correlate = set(from_cloned[f] for f in self._correlate) self._raw_columns = [clone(c) for c in self._raw_columns] for attr in ('_whereclause', '_having', '_order_by_clause', '_group_by_clause'): @@ -3359,11 +3517,17 @@ class Select(_SelectBaseMixin, FromClause): return (column_collections and list(self.columns) or []) + \ self._raw_columns + list(self._froms) + \ - [x for x in (self._whereclause, self._having, self._order_by_clause, self._group_by_clause) if x is not None] + [x for x in + (self._whereclause, self._having, + self._order_by_clause, self._group_by_clause) + if x is not None] @_generative def column(self, column): - """return a new select() construct with the given column expression added to its columns clause.""" + """return a new select() construct with the given column expression + added to its columns clause. + + """ column = _literal_as_column(column) @@ -3375,63 +3539,73 @@ class Select(_SelectBaseMixin, FromClause): @_generative def with_only_columns(self, columns): - """return a new select() construct with its columns clause replaced with the given columns.""" + """return a new select() construct with its columns clause replaced + with the given columns. + + """ self._raw_columns = [ - isinstance(c, _ScalarSelect) and c.self_group(against=operators.comma_op) or c - for c in - [_literal_as_column(c) for c in columns] + isinstance(c, _ScalarSelect) and + c.self_group(against=operators.comma_op) or c + for c in [_literal_as_column(c) for c in columns] ] @_generative def where(self, whereclause): - """return a new select() construct with the given expression added to its WHERE clause, joined - to the existing clause via AND, if any.""" + """return a new select() construct with the given expression added to its + WHERE clause, joined to the existing clause via AND, if any. + + """ self.append_whereclause(whereclause) @_generative def having(self, having): - """return a new select() construct with the given expression added to its HAVING clause, joined - to the existing clause via AND, if any.""" - + """return a new select() construct with the given expression added to its HAVING + clause, joined to the existing clause via AND, if any. + + """ self.append_having(having) @_generative def distinct(self): - """return a new select() construct which will apply DISTINCT to its columns clause.""" - + """return a new select() construct which will apply DISTINCT to its columns + clause. + + """ self._distinct = True @_generative def prefix_with(self, clause): - """return a new select() construct which will apply the given expression to the start of its - columns clause, not using any commas.""" + """return a new select() construct which will apply the given expression to the + start of its columns clause, not using any commas. + """ clause = _literal_as_text(clause) self._prefixes = self._prefixes + [clause] @_generative def select_from(self, fromclause): - """return a new select() construct with the given FROM expression applied to its list of - FROM objects.""" + """return a new select() construct with the given FROM expression applied to its + list of FROM objects. + """ fromclause = _literal_as_text(fromclause) self._froms = self._froms.union([fromclause]) @_generative def correlate(self, *fromclauses): - """return a new select() construct which will correlate the given FROM clauses to that - of an enclosing select(), if a match is found. - - By "match", the given fromclause must be present in this select's list of FROM objects - and also present in an enclosing select's list of FROM objects. - - Calling this method turns off the select's default behavior of "auto-correlation". Normally, - select() auto-correlates all of its FROM clauses to those of an embedded select when - compiled. - - If the fromclause is None, correlation is disabled for the returned select(). + """return a new select() construct which will correlate the given FROM clauses to + that of an enclosing select(), if a match is found. + + By "match", the given fromclause must be present in this select's list of FROM + objects and also present in an enclosing select's list of FROM objects. + + Calling this method turns off the select's default behavior of + "auto-correlation". Normally, select() auto-correlates all of its FROM clauses to + those of an embedded select when compiled. + + If the fromclause is None, correlation is disabled for the returned select(). """ self._should_correlate = False @@ -3447,8 +3621,10 @@ class Select(_SelectBaseMixin, FromClause): self._correlate = self._correlate.union([fromclause]) def append_column(self, column): - """append the given column expression to the columns clause of this select() construct.""" - + """append the given column expression to the columns clause of this select() + construct. + + """ column = _literal_as_column(column) if isinstance(column, _ScalarSelect): @@ -3459,8 +3635,10 @@ class Select(_SelectBaseMixin, FromClause): self._reset_exported() def append_prefix(self, clause): - """append the given columns clause prefix expression to this select() construct.""" - + """append the given columns clause prefix expression to this select() + construct. + + """ clause = _literal_as_text(clause) self._prefixes = self._prefixes.union([clause]) @@ -3490,7 +3668,8 @@ class Select(_SelectBaseMixin, FromClause): self._having = _literal_as_text(having) def append_from(self, fromclause): - """append the given FromClause expression to this select() construct's FROM clause. + """append the given FromClause expression to this select() construct's FROM + clause. """ if _is_literal(fromclause): @@ -3529,8 +3708,10 @@ class Select(_SelectBaseMixin, FromClause): return union(self, other, **kwargs) def union_all(self, other, **kwargs): - """return a SQL UNION ALL of this select() construct against the given selectable.""" - + """return a SQL UNION ALL of this select() construct against the given + selectable. + + """ return union_all(self, other, **kwargs) def except_(self, other, **kwargs): @@ -3539,18 +3720,24 @@ class Select(_SelectBaseMixin, FromClause): return except_(self, other, **kwargs) def except_all(self, other, **kwargs): - """return a SQL EXCEPT ALL of this select() construct against the given selectable.""" - + """return a SQL EXCEPT ALL of this select() construct against the given + selectable. + + """ return except_all(self, other, **kwargs) def intersect(self, other, **kwargs): - """return a SQL INTERSECT of this select() construct against the given selectable.""" - + """return a SQL INTERSECT of this select() construct against the given + selectable. + + """ return intersect(self, other, **kwargs) def intersect_all(self, other, **kwargs): - """return a SQL INTERSECT ALL of this select() construct against the given selectable.""" - + """return a SQL INTERSECT ALL of this select() construct against the given + selectable. + + """ return intersect_all(self, other, **kwargs) def bind(self): @@ -3581,7 +3768,7 @@ class _UpdateBase(ClauseElement): supports_execution = True _autocommit = True - + def _generate(self): s = self.__class__.__new__(self.__class__) s.__dict__ = self.__dict__.copy() @@ -3597,8 +3784,10 @@ class _UpdateBase(ClauseElement): return parameters def params(self, *arg, **kw): - raise NotImplementedError("params() is not supported for INSERT/UPDATE/DELETE statements." - " To set the values for an INSERT or UPDATE statement, use stmt.values(**parameters).") + raise NotImplementedError( + "params() is not supported for INSERT/UPDATE/DELETE statements." + " To set the values for an INSERT or UPDATE statement, use" + " stmt.values(**parameters).") def bind(self): return self._bind or self.table.bind @@ -3607,6 +3796,51 @@ class _UpdateBase(ClauseElement): self._bind = bind bind = property(bind, _set_bind) + _returning_re = re.compile(r'(?:firebird|postgres(?:ql)?)_returning') + def _process_deprecated_kw(self, kwargs): + for k in list(kwargs): + m = self._returning_re.match(k) + if m: + self._returning = kwargs.pop(k) + util.warn_deprecated( + "The %r argument is deprecated. Please use statement.returning(col1, col2, ...)" % k + ) + return kwargs + + @_generative + def returning(self, *cols): + """Add a RETURNING or equivalent clause to this statement. + + The given list of columns represent columns within the table + that is the target of the INSERT, UPDATE, or DELETE. Each + element can be any column expression. ``Table`` objects + will be expanded into their individual columns. + + Upon compilation, a RETURNING clause, or database equivalent, + will be rendered within the statement. For INSERT and UPDATE, + the values are the newly inserted/updated values. For DELETE, + the values are those of the rows which were deleted. + + Upon execution, the values of the columns to be returned + are made available via the result set and can be iterated + using ``fetchone()`` and similar. For DBAPIs which do not + natively support returning values (i.e. cx_oracle), + SQLAlchemy will approximate this behavior at the result level + so that a reasonable amount of behavioral neutrality is + provided. + + Note that not all databases/DBAPIs + support RETURNING. For those backends with no support, + an exception is raised upon compilation and/or execution. + For those who do support it, the functionality across backends + varies greatly, including restrictions on executemany() + and other statements which return multiple rows. Please + read the documentation notes for the database in use in + order to determine the availability of RETURNING. + + """ + self._returning = cols + class _ValuesBase(_UpdateBase): __visit_name__ = 'values_base' @@ -3617,14 +3851,15 @@ class _ValuesBase(_UpdateBase): @_generative def values(self, *args, **kwargs): - """specify the VALUES clause for an INSERT statement, or the SET clause for an UPDATE. + """specify the VALUES clause for an INSERT statement, or the SET clause for an + UPDATE. \**kwargs key=<somevalue> arguments \*args - A single dictionary can be sent as the first positional argument. This allows - non-string based keys, such as Column objects, to be used. + A single dictionary can be sent as the first positional argument. This + allows non-string based keys, such as Column objects, to be used. """ if args: @@ -3648,16 +3883,25 @@ class Insert(_ValuesBase): """ __visit_name__ = 'insert' - def __init__(self, table, values=None, inline=False, bind=None, prefixes=None, **kwargs): + def __init__(self, + table, + values=None, + inline=False, + bind=None, + prefixes=None, + returning=None, + **kwargs): _ValuesBase.__init__(self, table, values) self._bind = bind self.select = None self.inline = inline + self._returning = returning if prefixes: self._prefixes = [_literal_as_text(p) for p in prefixes] else: self._prefixes = [] - self.kwargs = kwargs + + self.kwargs = self._process_deprecated_kw(kwargs) def get_children(self, **kwargs): if self.select is not None: @@ -3688,15 +3932,24 @@ class Update(_ValuesBase): """ __visit_name__ = 'update' - def __init__(self, table, whereclause, values=None, inline=False, bind=None, **kwargs): + def __init__(self, + table, + whereclause, + values=None, + inline=False, + bind=None, + returning=None, + **kwargs): _ValuesBase.__init__(self, table, values) self._bind = bind + self._returning = returning if whereclause: self._whereclause = _literal_as_text(whereclause) else: self._whereclause = None self.inline = inline - self.kwargs = kwargs + + self.kwargs = self._process_deprecated_kw(kwargs) def get_children(self, **kwargs): if self._whereclause is not None: @@ -3711,9 +3964,10 @@ class Update(_ValuesBase): @_generative def where(self, whereclause): - """return a new update() construct with the given expression added to its WHERE clause, joined - to the existing clause via AND, if any.""" - + """return a new update() construct with the given expression added to its WHERE + clause, joined to the existing clause via AND, if any. + + """ if self._whereclause is not None: self._whereclause = and_(self._whereclause, _literal_as_text(whereclause)) else: @@ -3729,15 +3983,22 @@ class Delete(_UpdateBase): __visit_name__ = 'delete' - def __init__(self, table, whereclause, bind=None, **kwargs): + def __init__(self, + table, + whereclause, + bind=None, + returning =None, + **kwargs): self._bind = bind self.table = table + self._returning = returning + if whereclause: self._whereclause = _literal_as_text(whereclause) else: self._whereclause = None - self.kwargs = kwargs + self.kwargs = self._process_deprecated_kw(kwargs) def get_children(self, **kwargs): if self._whereclause is not None: diff --git a/lib/sqlalchemy/sql/operators.py b/lib/sqlalchemy/sql/operators.py index 7c21e8233..879f0f3e5 100644 --- a/lib/sqlalchemy/sql/operators.py +++ b/lib/sqlalchemy/sql/operators.py @@ -4,8 +4,13 @@ """Defines operators used in SQL expressions.""" from operator import ( - and_, or_, inv, add, mul, sub, div, mod, truediv, lt, le, ne, gt, ge, eq + and_, or_, inv, add, mul, sub, mod, truediv, lt, le, ne, gt, ge, eq ) + +# Py2K +from operator import (div,) +# end Py2K + from sqlalchemy.util import symbol @@ -88,7 +93,10 @@ _largest = symbol('_largest') _PRECEDENCE = { from_: 15, mul: 7, + truediv: 7, + # Py2K div: 7, + # end Py2K mod: 7, add: 6, sub: 6, diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index a5bd497ae..4471d4fb0 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -34,13 +34,10 @@ class VisitableType(type): """ def __init__(cls, clsname, bases, clsdict): - if cls.__name__ == 'Visitable': + if cls.__name__ == 'Visitable' or not hasattr(cls, '__visit_name__'): super(VisitableType, cls).__init__(clsname, bases, clsdict) return - assert hasattr(cls, '__visit_name__'), "`Visitable` descendants " \ - "should define `__visit_name__`" - # set up an optimized visit dispatch function # for use by the compiler visit_name = cls.__visit_name__ |
