summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2009-10-23 01:08:02 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2009-10-23 01:08:02 +0000
commita43a0e8b68c28b72404e85e4b4e8999443dd3fc5 (patch)
treef203e49eb40aaa6858eb969f625eeb899ac9c3c6 /lib/sqlalchemy/sql
parent9ae821ee660a2d03cae591798c05cfdbd8bb3ca6 (diff)
downloadsqlalchemy-a43a0e8b68c28b72404e85e4b4e8999443dd3fc5.tar.gz
- insert() and update() constructs can now embed bindparam()
objects using names that match the keys of columns. These bind parameters will circumvent the usual route to those keys showing up in the VALUES or SET clause of the generated SQL. [ticket:1579]
Diffstat (limited to 'lib/sqlalchemy/sql')
-rw-r--r--lib/sqlalchemy/sql/compiler.py13
1 files changed, 9 insertions, 4 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 4c3130879..5f5b31c68 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -47,7 +47,7 @@ RESERVED_WORDS = set([
'using', 'verbose', 'when', 'where'])
LEGAL_CHARACTERS = re.compile(r'^[A-Z0-9_$]+$', re.I)
-ILLEGAL_INITIAL_CHARACTERS = re.compile(r'[0-9$]')
+ILLEGAL_INITIAL_CHARACTERS = set(xrange(0, 10)).union(['$'])
BIND_PARAMS = re.compile(r'(?<![:\w\$\x5c]):([\w\$]+)(?![:\w\$])', re.UNICODE)
BIND_PARAMS_ESC = re.compile(r'\x5c(:[\w\$]+)(?![:\w\$])', re.UNICODE)
@@ -776,12 +776,17 @@ class SQLCompiler(engine.Compiled):
self.prefetch = []
self.returning = []
+ # get the keys of explicitly constructed bindparam() objects
+ bind_names = set(b.key for b in visitors.iterate(stmt, {}) if b.__visit_name__ == 'bindparam')
+ if stmt.parameters:
+ bind_names.update(stmt.parameters)
+
# no parameters in the statement, no parameters in the
# compiled params - return binds for all columns
if self.column_keys is None and stmt.parameters is None:
return [
(c, self._create_crud_bind_param(c, None, required=True))
- for c in stmt.table.columns
+ for c in stmt.table.columns if c.key not in bind_names
]
required = object()
@@ -792,7 +797,7 @@ class SQLCompiler(engine.Compiled):
parameters = {}
else:
parameters = dict((sql._column_as_key(key), required)
- for key in self.column_keys)
+ for key in self.column_keys if key not in bind_names)
if stmt.parameters is not None:
for k, v in stmt.parameters.iteritems():
@@ -1312,7 +1317,7 @@ class IdentifierPreparer(object):
"""Return True if the given identifier requires quoting."""
lc_value = value.lower()
return (lc_value in self.reserved_words
- or self.illegal_initial_characters.match(value[0])
+ or value[0] in self.illegal_initial_characters
or not self.legal_characters.match(unicode(value))
or (lc_value != value))