diff options
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/databases/postgres.py | 50 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/expression.py | 24 |
2 files changed, 61 insertions, 13 deletions
diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py index 305e8e831..03b3fd042 100644 --- a/lib/sqlalchemy/databases/postgres.py +++ b/lib/sqlalchemy/databases/postgres.py @@ -10,13 +10,20 @@ PostgreSQL supports partial indexes. To create them pass a posgres_where option to the Index constructor:: Index('my_index', my_table.c.id, postgres_where=tbl.c.value > 10) + +PostgreSQL 8.2+ supports returning a result set from inserts and updates. +To use this pass the column/expression list to the postgres_returning +parameter when creating the queries:: + + raises = tbl.update(empl.c.sales > 100, values=dict(salary=empl.c.salary * 1.1), + postgres_returning=[empl.c.id, empl.c.salary]).execute().fetchall() """ import re, random, warnings, string from sqlalchemy import sql, schema, exceptions, util from sqlalchemy.engine import base, default -from sqlalchemy.sql import compiler +from sqlalchemy.sql import compiler, expression from sqlalchemy.sql import operators as sql_operators from sqlalchemy import types as sqltypes @@ -198,13 +205,27 @@ def descriptor(): ]} SELECT_RE = re.compile( - r'\s*(?:SELECT|FETCH)', + r'\s*(?:SELECT|FETCH|(UPDATE|INSERT))', + re.I | re.UNICODE) + +RETURNING_RE = re.compile( + 'RETURNING', + re.I | re.UNICODE) + +# This finds if the RETURNING is not inside a quoted/commented values. Handles string literals, +# quoted identifiers, dollar quotes, SQL comments and C style multiline comments. This does not +# handle correctly nested C style quotes, lets hope no one does the following: +# UPDATE tbl SET x=y /* foo /* bar */ RETURNING */ +RETURNING_QUOTED_RE = re.compile( + '\\s*(?:UPDATE|INSERT)\\s(?:[^\'"$/-]|-(?!-)|/(?!\\*)|"(?:[^"]|"")*"|\'(?:[^\']|\'\')*\'|\\$(?P<dquote>[^$]*)\\$.*?\\$(?P=dquote)\\$|--[^\n]*\n|/\\*([^*]|\\*(?!/))*\\*/)*\\sRETURNING', re.I | re.UNICODE) class PGExecutionContext(default.DefaultExecutionContext): def is_select(self): - return SELECT_RE.match(self.statement) + m = SELECT_RE.match(self.statement) + return m and (not m.group(1) or (RETURNING_RE.search(self.statement) + and RETURNING_QUOTED_RE.match(self.statement))) def create_cursor(self): # executing a default or Sequence standalone creates an execution context without a statement. @@ -598,6 +619,29 @@ class PGCompiler(compiler.DefaultCompiler): else: return super(PGCompiler, self).for_update_clause(select) + def _append_returning(self, text, stmt): + returning_cols = stmt.kwargs.get('postgres_returning', None) + if returning_cols: + def flatten_columnlist(collist): + for c in collist: + if isinstance(c, expression.Selectable): + for co in c.columns: + yield co + else: + yield c + columns = [self.process(c) for c in flatten_columnlist(returning_cols)] + text += ' RETURNING ' + string.join(columns, ', ') + + return text + + def visit_update(self, update_stmt): + text = super(PGCompiler, self).visit_update(update_stmt) + return self._append_returning(text, update_stmt) + + def visit_insert(self, insert_stmt): + text = super(PGCompiler, self).visit_insert(insert_stmt) + return self._append_returning(text, insert_stmt) + class PGSchemaGenerator(compiler.SchemaGenerator): def get_column_specification(self, column, **kwargs): colspec = self.preparer.format_column(column) diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 9a9cf65d2..6f3ee94ab 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -249,7 +249,7 @@ def subquery(alias, *args, **kwargs): return Select(*args, **kwargs).alias(alias) -def insert(table, values=None, inline=False): +def insert(table, values=None, inline=False, **kwargs): """Return an [sqlalchemy.sql.expression#Insert] clause element. Similar functionality is available via the ``insert()`` method on @@ -287,9 +287,9 @@ def insert(table, values=None, inline=False): against the ``INSERT`` statement. """ - return Insert(table, values, inline=inline) + return Insert(table, values, inline=inline, **kwargs) -def update(table, whereclause=None, values=None, inline=False): +def update(table, whereclause=None, values=None, inline=False, **kwargs): """Return an [sqlalchemy.sql.expression#Update] clause element. Similar functionality is available via the ``update()`` method on @@ -332,7 +332,7 @@ def update(table, whereclause=None, values=None, inline=False): against the ``UPDATE`` statement. """ - return Update(table, whereclause=whereclause, values=values, inline=inline) + return Update(table, whereclause=whereclause, values=values, inline=inline, **kwargs) def delete(table, whereclause = None, **kwargs): """Return a [sqlalchemy.sql.expression#Delete] clause element. @@ -2699,11 +2699,11 @@ class TableClause(FromClause): def select(self, whereclause = None, **params): return select([self], whereclause, **params) - def insert(self, values=None, inline=False): - return insert(self, values=values, inline=inline) + def insert(self, values=None, inline=False, **kwargs): + return insert(self, values=values, inline=inline, **kwargs) - def update(self, whereclause=None, values=None, inline=False): - return update(self, whereclause=whereclause, values=values, inline=inline) + def update(self, whereclause=None, values=None, inline=False, **kwargs): + return update(self, whereclause=whereclause, values=values, inline=inline, **kwargs) def delete(self, whereclause=None): return delete(self, whereclause) @@ -3356,12 +3356,14 @@ class _UpdateBase(ClauseElement): return self.table.bind class Insert(_UpdateBase): - def __init__(self, table, values=None, inline=False): + def __init__(self, table, values=None, inline=False, **kwargs): self.table = table self.select = None self.inline=inline self.parameters = self._process_colparams(values) + self.kwargs = kwargs + def get_children(self, **kwargs): if self.select is not None: return self.select, @@ -3383,12 +3385,14 @@ class Insert(_UpdateBase): return u class Update(_UpdateBase): - def __init__(self, table, whereclause, values=None, inline=False): + def __init__(self, table, whereclause, values=None, inline=False, **kwargs): self.table = table self._whereclause = whereclause self.inline = inline self.parameters = self._process_colparams(values) + self.kwargs = kwargs + def get_children(self, **kwargs): if self._whereclause is not None: return self._whereclause, |
