summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2011-03-20 12:49:28 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2011-03-20 12:49:28 -0400
commit90335a89a98df23db7a3ae1233eb4fbb5743d2e8 (patch)
tree9a4ac236f83696709bd355dcac22552aeb177694 /lib/sqlalchemy
parent75c78aa714ca55818f0ba12a67cf2f77927b68f7 (diff)
downloadsqlalchemy-90335a89a98df23db7a3ae1233eb4fbb5743d2e8.tar.gz
- Added new generic function "next_value()", accepts
a Sequence object as its argument and renders the appropriate "next value" generation string on the target platform, if supported. Also provides ".next_value()" method on Sequence itself. [ticket:2085] - added tests for all the conditions described in [ticket:2085] - postgresql dialect will exec/compile a Sequence that has "optional=True". the optional flag is now only checked specifically in the context of a Table primary key evaulation. - func.next_value() or other SQL expression can be embedded directly into an insert() construct, and if implicit or explicit "returning" is used in conjunction with a primary key column, the newly generated value will be present in result.inserted_primary_key. [ticket:2084]
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/dialects/oracle/base.py3
-rw-r--r--lib/sqlalchemy/dialects/postgresql/base.py19
-rw-r--r--lib/sqlalchemy/schema.py58
-rw-r--r--lib/sqlalchemy/sql/compiler.py26
-rw-r--r--lib/sqlalchemy/sql/expression.py10
-rw-r--r--lib/sqlalchemy/sql/functions.py26
6 files changed, 101 insertions, 41 deletions
diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py
index d3c1bc139..72411d735 100644
--- a/lib/sqlalchemy/dialects/oracle/base.py
+++ b/lib/sqlalchemy/dialects/oracle/base.py
@@ -392,7 +392,8 @@ class OracleCompiler(compiler.SQLCompiler):
return ""
def default_from(self):
- """Called when a ``SELECT`` statement has no froms, and no ``FROM`` clause is to be appended.
+ """Called when a ``SELECT`` statement has no froms,
+ and no ``FROM`` clause is to be appended.
The Oracle compiler tacks a "FROM DUAL" to the statement.
"""
diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py
index 8bceeef65..cc2f461f9 100644
--- a/lib/sqlalchemy/dialects/postgresql/base.py
+++ b/lib/sqlalchemy/dialects/postgresql/base.py
@@ -461,10 +461,7 @@ class PGCompiler(compiler.SQLCompiler):
return value
def visit_sequence(self, seq):
- if seq.optional:
- return None
- else:
- return "nextval('%s')" % self.preparer.format_sequence(seq)
+ return "nextval('%s')" % self.preparer.format_sequence(seq)
def limit_clause(self, select):
text = ""
@@ -717,23 +714,19 @@ class DropEnumType(schema._CreateDropBase):
class PGExecutionContext(default.DefaultExecutionContext):
def fire_sequence(self, seq, type_):
- if not seq.optional:
- return self._execute_scalar(("select nextval('%s')" % \
- self.dialect.identifier_preparer.format_sequence(seq)), type_)
- else:
- return None
+ return self._execute_scalar(("select nextval('%s')" % \
+ self.dialect.identifier_preparer.format_sequence(seq)), type_)
def get_insert_default(self, column):
if column.primary_key and column is column.table._autoincrement_column:
- if (isinstance(column.server_default, schema.DefaultClause) and
- column.server_default.arg is not None):
+ if column.server_default and column.server_default.has_argument:
# pre-execute passive defaults on primary key columns
return self._execute_scalar("select %s" %
- column.server_default.arg, column.type)
+ column.server_default.arg, column.type)
elif (column.default is None or
- (isinstance(column.default, schema.Sequence) and
+ (column.default.is_sequence and
column.default.optional)):
# execute the sequence associated with a SERIAL primary
diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py
index 70d9013d6..bc3eac213 100644
--- a/lib/sqlalchemy/schema.py
+++ b/lib/sqlalchemy/schema.py
@@ -750,7 +750,7 @@ class Column(SchemaItem, expression.ColumnClause):
if isinstance(self.default, str):
# end Py2K
util.warn("Unicode column received non-unicode "
- "default value.")
+ "default value.")
args.append(ColumnDefault(self.default))
if self.server_default is not None:
@@ -1266,7 +1266,16 @@ class ForeignKey(SchemaItem):
self.constraint._set_parent_with_dispatch(table)
table.foreign_keys.add(self)
-class DefaultGenerator(SchemaItem):
+class _NotAColumnExpr(object):
+ def _not_a_column_expr(self):
+ raise exc.InvalidRequestError(
+ "This %s cannot be used directly "
+ "as a column expression." % self.__class__.__name__)
+
+ __clause_element__ = self_group = lambda self: self._not_a_column_expr()
+ _from_objects = property(lambda self: self._not_a_column_expr())
+
+class DefaultGenerator(_NotAColumnExpr, SchemaItem):
"""Base class for column *default* values."""
__visit_name__ = 'default_generator'
@@ -1392,26 +1401,26 @@ class ColumnDefault(DefaultGenerator):
class Sequence(DefaultGenerator):
"""Represents a named database sequence.
-
+
The :class:`.Sequence` object represents the name and configurational
parameters of a database sequence. It also represents
a construct that can be "executed" by a SQLAlchemy :class:`.Engine`
or :class:`.Connection`, rendering the appropriate "next value" function
for the target database and returning a result.
-
+
The :class:`.Sequence` is typically associated with a primary key column::
-
+
some_table = Table('some_table', metadata,
Column('id', Integer, Sequence('some_table_seq'), primary_key=True)
)
-
+
When CREATE TABLE is emitted for the above :class:`.Table`, if the
target platform supports sequences, a CREATE SEQUENCE statement will
be emitted as well. For platforms that don't support sequences,
the :class:`.Sequence` construct is ignored.
-
+
See also: :class:`.CreateSequence` :class:`.DropSequence`
-
+
"""
__visit_name__ = 'sequence'
@@ -1422,7 +1431,7 @@ class Sequence(DefaultGenerator):
optional=False, quote=None, metadata=None,
for_update=False):
"""Construct a :class:`.Sequence` object.
-
+
:param name: The name of the sequence.
:param start: the starting index of the sequence. This value is
used when the CREATE SEQUENCE command is emitted to the database
@@ -1455,7 +1464,7 @@ class Sequence(DefaultGenerator):
DROP SEQUENCE DDL commands will be emitted corresponding to this
:class:`.Sequence` when :meth:`.MetaData.create_all` and
:meth:`.MetaData.drop_all` are invoked (new in 0.7).
-
+
Note that when a :class:`.Sequence` is applied to a :class:`.Column`,
the :class:`.Sequence` is automatically associated with the
:class:`.MetaData` object of that column's parent :class:`.Table`,
@@ -1467,7 +1476,7 @@ class Sequence(DefaultGenerator):
with a :class:`.Column`, should be invoked for UPDATE statements
on that column's table, rather than for INSERT statements, when
no value is otherwise present for that column in the statement.
-
+
"""
super(Sequence, self).__init__(for_update=for_update)
self.name = name
@@ -1488,6 +1497,14 @@ class Sequence(DefaultGenerator):
def is_clause_element(self):
return False
+ def next_value(self):
+ """Return a :class:`.next_value` function element
+ which will render the appropriate increment function
+ for this :class:`.Sequence` within any SQL expression.
+
+ """
+ return expression.func.next_value(self, bind=self.bind)
+
def __repr__(self):
return "Sequence(%s)" % ', '.join(
[repr(self.name)] +
@@ -1526,8 +1543,16 @@ class Sequence(DefaultGenerator):
bind = _bind_or_error(self)
bind.drop(self, checkfirst=checkfirst)
+ def _not_a_column_expr(self):
+ raise exc.InvalidRequestError(
+ "This %s cannot be used directly "
+ "as a column expression. Use func.next_value(sequence) "
+ "to produce a 'next value' function that's usable "
+ "as a column element."
+ % self.__class__.__name__)
-class FetchedValue(events.SchemaEventTarget):
+
+class FetchedValue(_NotAColumnExpr, events.SchemaEventTarget):
"""A marker for a transparent database-side default.
Use :class:`.FetchedValue` when the database is configured
@@ -1544,6 +1569,7 @@ class FetchedValue(events.SchemaEventTarget):
"""
is_server_default = True
reflected = False
+ has_argument = False
def __init__(self, for_update=False):
self.for_update = for_update
@@ -1581,6 +1607,8 @@ class DefaultClause(FetchedValue):
"""
+ has_argument = True
+
def __init__(self, arg, for_update=False, _reflected=False):
util.assert_arg_type(arg, (basestring,
expression.ClauseElement,
@@ -2508,7 +2536,7 @@ class DDLElement(expression.Executable, expression.ClauseElement):
Optional keyword argument - a list of Table objects which are to
be created/ dropped within a MetaData.create_all() or drop_all()
method call.
-
+
:state:
Optional keyword argument - will be the ``state`` argument
passed to this function.
@@ -2517,13 +2545,13 @@ class DDLElement(expression.Executable, expression.ClauseElement):
Keyword argument, will be True if the 'checkfirst' flag was
set during the call to ``create()``, ``create_all()``,
``drop()``, ``drop_all()``.
-
+
If the callable returns a true value, the DDL statement will be
executed.
:param state: any value which will be passed to the callable_
as the ``state`` keyword argument.
-
+
See also:
:class:`.DDLEvents`
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index d6a020bdc..7547e1662 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -492,6 +492,14 @@ class SQLCompiler(engine.Compiled):
return ".".join(func.packagenames + [name]) % \
{'expr':self.function_argspec(func, **kwargs)}
+ def visit_next_value_func(self, next_value, **kw):
+ return self.visit_sequence(next_value.sequence)
+
+ def visit_sequence(self, sequence):
+ raise NotImplementedError(
+ "Dialect '%s' does not support sequence increments." % self.dialect.name
+ )
+
def function_argspec(self, func, **kwargs):
return func.clause_expr._compiler_dispatch(self, **kwargs)
@@ -926,9 +934,6 @@ class SQLCompiler(engine.Compiled):
join.onclause._compiler_dispatch(self, **kwargs)
)
- def visit_sequence(self, seq):
- return None
-
def visit_insert(self, insert_stmt):
self.isinsert = True
colparams = self._get_colparams(insert_stmt)
@@ -1075,6 +1080,9 @@ class SQLCompiler(engine.Compiled):
if sql._is_literal(value):
value = self._create_crud_bind_param(
c, value, required=value is required)
+ elif c.primary_key and implicit_returning:
+ self.returning.append(c)
+ value = self.process(value.self_group())
else:
self.postfetch.append(c)
value = self.process(value.self_group())
@@ -1092,8 +1100,10 @@ class SQLCompiler(engine.Compiled):
if implicit_returning:
if c.default is not None:
if c.default.is_sequence:
- proc = self.process(c.default)
- if proc is not None:
+ if self.dialect.supports_sequences and \
+ (not c.default.optional or \
+ not self.dialect.sequences_optional):
+ proc = self.process(c.default)
values.append((c, proc))
self.returning.append(c)
elif c.default.is_clause_element:
@@ -1124,8 +1134,10 @@ class SQLCompiler(engine.Compiled):
elif c.default is not None:
if c.default.is_sequence:
- proc = self.process(c.default)
- if proc is not None:
+ if self.dialect.supports_sequences and \
+ (not c.default.optional or \
+ not self.dialect.sequences_optional):
+ proc = self.process(c.default)
values.append((c, proc))
if not c.primary_key:
self.postfetch.append(c)
diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py
index 9aed957d2..d49f12150 100644
--- a/lib/sqlalchemy/sql/expression.py
+++ b/lib/sqlalchemy/sql/expression.py
@@ -1178,14 +1178,16 @@ def _column_as_key(element):
return element.key
def _literal_as_text(element):
- if hasattr(element, '__clause_element__'):
+ if isinstance(element, Visitable):
+ return element
+ elif hasattr(element, '__clause_element__'):
return element.__clause_element__()
elif isinstance(element, basestring):
return _TextClause(unicode(element))
- elif not isinstance(element, Visitable):
- raise exc.ArgumentError("SQL expression object or string expected.")
else:
- return element
+ raise exc.ArgumentError(
+ "SQL expression object or string expected."
+ )
def _clause_element_as_expr(element):
if hasattr(element, '__clause_element__'):
diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py
index 10eaa577b..717816656 100644
--- a/lib/sqlalchemy/sql/functions.py
+++ b/lib/sqlalchemy/sql/functions.py
@@ -4,7 +4,7 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-from sqlalchemy import types as sqltypes
+from sqlalchemy import types as sqltypes, schema
from sqlalchemy.sql.expression import (
ClauseList, Function, _literal_as_binds, text, _type_from_args
)
@@ -29,6 +29,29 @@ class GenericFunction(Function):
self.type = sqltypes.to_instance(
type_ or getattr(self, '__return_type__', None))
+
+class next_value(Function):
+ """Represent the 'next value', given a :class:`.Sequence`
+ as it's single argument.
+
+ Compiles into the appropriate function on each backend,
+ or will raise NotImplementedError if used on a backend
+ that does not provide support for sequences.
+
+ """
+ type = sqltypes.Integer()
+ name = "next_value"
+
+ def __init__(self, seq, **kw):
+ assert isinstance(seq, schema.Sequence), \
+ "next_value() accepts a Sequence object as input."
+ self._bind = kw.get('bind', None)
+ self.sequence = seq
+
+ @property
+ def _from_objects(self):
+ return []
+
class AnsiFunction(GenericFunction):
def __init__(self, **kwargs):
GenericFunction.__init__(self, **kwargs)
@@ -52,6 +75,7 @@ class min(ReturnTypeFromArgs):
class sum(ReturnTypeFromArgs):
pass
+
class now(GenericFunction):
__return_type__ = sqltypes.DateTime