summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2006-08-12 17:28:15 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2006-08-12 17:28:15 +0000
commit9e8fad2abcce364253352f042836bf58ce8f4f81 (patch)
tree5058c15280a2e56d454670deeb7a53dd8b6b1f67 /lib/sqlalchemy
parentfb88b031d916ea91ce9af760a67ea27e00113c14 (diff)
downloadsqlalchemy-9e8fad2abcce364253352f042836bf58ce8f4f81.tar.gz
quoting facilities set up so that database-specific quoting can be
turned on for individual table, schema, and column identifiers when used in all queries/creates/drops. Enabled via "quote=True" in Table or Column, as well as "quote_schema=True" in Table. Thanks to Aaron Spike for his excellent efforts. [ticket:155]
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/ansisql.py109
-rw-r--r--lib/sqlalchemy/databases/firebird.py10
-rw-r--r--lib/sqlalchemy/databases/mssql.py15
-rw-r--r--lib/sqlalchemy/databases/mysql.py15
-rw-r--r--lib/sqlalchemy/databases/postgres.py8
-rw-r--r--lib/sqlalchemy/databases/sqlite.py12
-rw-r--r--lib/sqlalchemy/schema.py11
-rw-r--r--lib/sqlalchemy/sql.py4
8 files changed, 156 insertions, 28 deletions
diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py
index b85f67d47..e1791324d 100644
--- a/lib/sqlalchemy/ansisql.py
+++ b/lib/sqlalchemy/ansisql.py
@@ -42,6 +42,11 @@ class ANSIDialect(default.DefaultDialect):
def compiler(self, statement, parameters, **kwargs):
return ANSICompiler(self, statement, parameters, **kwargs)
+ def preparer(self):
+ """return an IdenfifierPreparer.
+
+ This object is used to format table and column names including proper quoting and case conventions."""
+ return ANSIIdentifierPreparer()
class ANSICompiler(sql.Compiled):
"""default implementation of Compiled, which compiles ClauseElements into ANSI-compliant SQL strings."""
@@ -69,6 +74,7 @@ class ANSICompiler(sql.Compiled):
self.bindtemplate = ":%s"
self.paramstyle = dialect.paramstyle
self.positional = dialect.positional
+ self.preparer = dialect.preparer()
def after_compile(self):
# this re will search for params like :param
@@ -170,19 +176,18 @@ class ANSICompiler(sql.Compiled):
# for this column which is used to translate result set values
self.typemap.setdefault(column.name.lower(), column.type)
if column.table is None or not column.table.named_with_column():
- self.strings[column] = column.name
+ self.strings[column] = self.preparer.format_column(column)
else:
if column.table.oid_column is column:
n = self.dialect.oid_column_name()
if n is not None:
- self.strings[column] = "%s.%s" % (column.table.name, n)
+ self.strings[column] = "%s.%s" % (self.preparer.format_table(column.table, use_schema=False), n)
elif len(column.table.primary_key) != 0:
- self.strings[column] = "%s.%s" % (column.table.name, column.table.primary_key[0].name)
+ self.strings[column] = self.preparer.format_column_with_table(column.table.primary_key[0])
else:
self.strings[column] = None
else:
- self.strings[column] = "%s.%s" % (column.table.name, column.name)
-
+ self.strings[column] = self.preparer.format_column_with_table(column)
def visit_fromclause(self, fromclause):
self.froms[fromclause] = fromclause.from_name
@@ -427,7 +432,7 @@ class ANSICompiler(sql.Compiled):
return " OFFSET " + str(select.offset)
def visit_table(self, table):
- self.froms[table] = table.fullname
+ self.froms[table] = self.preparer.format_table(table)
self.strings[table] = ""
def visit_join(self, join):
@@ -501,7 +506,7 @@ class ANSICompiler(sql.Compiled):
else:
return self.get_str(p)
- text = ("INSERT INTO " + insert_stmt.table.fullname + " (" + string.join([c[0].name for c in colparams], ', ') + ")" +
+ text = ("INSERT INTO " + self.preparer.format_table(insert_stmt.table) + " (" + string.join([self.preparer.format_column(c[0]) for c in colparams], ', ') + ")" +
" VALUES (" + string.join([create_param(c[1]) for c in colparams], ', ') + ")")
self.strings[insert_stmt] = text
@@ -532,7 +537,7 @@ class ANSICompiler(sql.Compiled):
else:
return self.get_str(p)
- text = "UPDATE " + update_stmt.table.fullname + " SET " + string.join(["%s=%s" % (c[0].name, create_param(c[1])) for c in colparams], ', ')
+ text = "UPDATE " + self.preparer.format_table(update_stmt.table) + " SET " + string.join(["%s=%s" % (self.preparer.format_column(c[0]), create_param(c[1])) for c in colparams], ', ')
if update_stmt.whereclause:
text += " WHERE " + self.get_str(update_stmt.whereclause)
@@ -596,7 +601,7 @@ class ANSICompiler(sql.Compiled):
return values
def visit_delete(self, delete_stmt):
- text = "DELETE FROM " + delete_stmt.table.fullname
+ text = "DELETE FROM " + self.preparer.format_table(delete_stmt.table)
if delete_stmt.whereclause:
text += " WHERE " + self.get_str(delete_stmt.whereclause)
@@ -612,6 +617,8 @@ class ANSISchemaGenerator(engine.SchemaIterator):
super(ANSISchemaGenerator, self).__init__(engine, proxy, **params)
self.checkfirst = checkfirst
self.connection = connection
+ self.preparer = self.engine.dialect.preparer()
+
def get_column_specification(self, column, first_pk=False):
raise NotImplementedError()
@@ -622,7 +629,7 @@ class ANSISchemaGenerator(engine.SchemaIterator):
if self.checkfirst and self.engine.dialect.has_table(self.connection, table.name):
return
- self.append("\nCREATE TABLE " + table.fullname + " (")
+ self.append("\nCREATE TABLE " + self.preparer.format_table(table) + " (")
separator = "\n"
@@ -665,16 +672,16 @@ class ANSISchemaGenerator(engine.SchemaIterator):
if len(constraint) == 0:
return
self.append(", \n")
- self.append("\tPRIMARY KEY (%s)" % string.join([c.name for c in constraint],', '))
+ self.append("\tPRIMARY KEY (%s)" % string.join([self.preparer.format_column(c) for c in constraint],', '))
def visit_foreign_key_constraint(self, constraint):
self.append(", \n\t ")
if constraint.name is not None:
self.append("CONSTRAINT %s " % constraint.name)
self.append("FOREIGN KEY(%s) REFERENCES %s (%s)" % (
- string.join([f.parent.name for f in constraint.elements], ', '),
- list(constraint.elements)[0].column.table.fullname,
- string.join([f.column.name for f in constraint.elements], ', ')
+ string.join([self.preparer.format_column(f.parent) for f in constraint.elements], ', '),
+ self.preparer.format_table(list(constraint.elements)[0].column.table),
+ string.join([self.preparer.format_column(f.column) for f in constraint.elements], ', ')
))
if constraint.ondelete is not None:
self.append(" ON DELETE %s" % constraint.ondelete)
@@ -689,16 +696,16 @@ class ANSISchemaGenerator(engine.SchemaIterator):
if index.unique:
self.append('UNIQUE ')
self.append('INDEX %s ON %s (%s)' \
- % (index.name, index.table.fullname,
- string.join([c.name for c in index.columns], ', ')))
+ % (index.name, self.preparer.format_table(index.table),
+ string.join([self.preparer.format_column(c) for c in index.columns], ', ')))
self.execute()
-
class ANSISchemaDropper(engine.SchemaIterator):
def __init__(self, engine, proxy, connection=None, checkfirst=False, **params):
super(ANSISchemaDropper, self).__init__(engine, proxy, **params)
self.checkfirst = checkfirst
self.connection = connection
+ self.preparer = self.engine.dialect.preparer()
def visit_index(self, index):
self.append("\nDROP INDEX " + index.name)
@@ -709,9 +716,73 @@ class ANSISchemaDropper(engine.SchemaIterator):
# no need to drop them individually
if self.checkfirst and not self.engine.dialect.has_table(self.connection, table.name):
return
- self.append("\nDROP TABLE " + table.fullname)
+ self.append("\nDROP TABLE " + self.preparer.format_table(table))
self.execute()
-
class ANSIDefaultRunner(engine.DefaultRunner):
pass
+
+class ANSIIdentifierPreparer(object):
+ """Transforms identifiers into ANSI-Compliant delimited identifiers where required"""
+ def __init__(self, initial_quote='"', final_quote=None, omit_schema=False):
+ """Constructs a new ANSIIdentifierPreparer object.
+
+ initial_quote - Character that begins a delimited identifier
+ final_quote - Caracter that ends a delimited identifier. defaults to initial_quote.
+
+ omit_schema - prevent prepending schema name. useful for databases that do not support schemae
+ """
+ self.initial_quote = initial_quote
+ self.final_quote = final_quote or self.initial_quote
+ self.omit_schema = omit_schema
+
+ def _escape_identifier(self, value):
+ return value.replace('"', '""')
+
+ def _quote_identifier(self, value):
+ return self.initial_quote + self._escape_identifier(value) + self.final_quote
+
+ def _fold_identifier_case(self, value):
+ return value
+ # ANSI SQL calls for the case of all unquoted identifiers to be folded to UPPER.
+ # some tests would need to be rewritten if this is done.
+ #return value.upper()
+
+ def _prepare_table(self, table, use_schema=False):
+ names = []
+ if table.quote:
+ names.append(self._quote_identifier(table.name))
+ else:
+ names.append(self._fold_identifier_case(table.name))
+
+ if not self.omit_schema and use_schema and table.schema:
+ if table.quote_schema:
+ names.insert(0, self._quote_identifier(table.schema))
+ else:
+ names.insert(0, self._fold_identifier_case(table.schema))
+
+ return ".".join(names)
+
+ def _prepare_column(self, column, use_table=True, **kwargs):
+ names = []
+ if column.quote:
+ names.append(self._quote_identifier(column.name))
+ else:
+ names.append(self._fold_identifier_case(column.name))
+
+ if use_table:
+ names.insert(0, self._prepare_table(column.table, **kwargs))
+
+ return ".".join(names)
+
+ def format_table(self, table, use_schema=True):
+ """Prepare a quoted table and schema name"""
+ return self._prepare_table(table, use_schema=use_schema)
+
+ def format_column(self, column):
+ """Prepare a quoted column name"""
+ return self._prepare_column(column, use_table=False)
+
+ def format_column_with_table(self, column):
+ """Prepare a quoted column name with table name"""
+ return self._prepare_column(column)
diff --git a/lib/sqlalchemy/databases/firebird.py b/lib/sqlalchemy/databases/firebird.py
index bef185597..1cad4f37a 100644
--- a/lib/sqlalchemy/databases/firebird.py
+++ b/lib/sqlalchemy/databases/firebird.py
@@ -97,7 +97,10 @@ class FireBirdExecutionContext(default.DefaultExecutionContext):
def defaultrunner(self, proxy):
return FBDefaultRunner(self, proxy)
-
+
+ def preparer(self):
+ return FBIdentifierPreparer()
+
class FireBirdDialect(ansisql.ANSIDialect):
def __init__(self, module = None, **params):
global _initialized_kb
@@ -298,7 +301,7 @@ class FBCompiler(ansisql.ANSICompiler):
class FBSchemaGenerator(ansisql.ANSISchemaGenerator):
def get_column_specification(self, column, **kwargs):
- colspec = column.name
+ colspec = self.preparer.format_column(column)
colspec += " " + column.type.engine_impl(self.engine).get_col_spec()
default = self.get_column_default_string(column)
if default is not None:
@@ -325,5 +328,8 @@ class FBDefaultRunner(ansisql.ANSIDefaultRunner):
def visit_sequence(self, seq):
return self.proxy("SELECT gen_id(" + seq.name + ", 1) FROM rdb$database").fetchone()[0]
+class FBIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
+ def __init__(self):
+ super(FBIdentifierPreparer,self).__init__(omit_schema=True)
dialect = FireBirdDialect
diff --git a/lib/sqlalchemy/databases/mssql.py b/lib/sqlalchemy/databases/mssql.py
index 690b71863..414ca87e7 100644
--- a/lib/sqlalchemy/databases/mssql.py
+++ b/lib/sqlalchemy/databases/mssql.py
@@ -268,6 +268,9 @@ class MSSQLDialect(ansisql.ANSIDialect):
def defaultrunner(self, engine, proxy):
return MSSQLDefaultRunner(engine, proxy)
+ def preparer(self):
+ return MSSQLIdentifierPreparer()
+
def get_default_schema_name(self):
return "dbo"
@@ -510,7 +513,7 @@ class MSSQLCompiler(ansisql.ANSICompiler):
class MSSQLSchemaGenerator(ansisql.ANSISchemaGenerator):
def get_column_specification(self, column, **kwargs):
- colspec = column.name + " " + column.type.engine_impl(self.engine).get_col_spec()
+ colspec = self.preparer.format_column(column) + " " + column.type.engine_impl(self.engine).get_col_spec()
# install a IDENTITY Sequence if we have an implicit IDENTITY column
if column.primary_key and isinstance(column.type, sqltypes.Integer):
@@ -538,4 +541,14 @@ class MSSQLSchemaDropper(ansisql.ANSISchemaDropper):
class MSSQLDefaultRunner(ansisql.ANSIDefaultRunner):
pass
+class MSSQLIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
+ def __init__(self):
+ super(MSSQLIdentifierPreparer, self).__init__(initial_quote='[', final_quote=']')
+ def _escape_identifier(self, value):
+ #TODO: determin MSSQL's escapeing rules
+ return value
+ def _fold_identifier_case(self, value):
+ #TODO: determin MSSQL's case folding rules
+ return value
+
dialect = MSSQLDialect
diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py
index df661daac..56f2125ac 100644
--- a/lib/sqlalchemy/databases/mysql.py
+++ b/lib/sqlalchemy/databases/mysql.py
@@ -298,6 +298,9 @@ class MySQLDialect(ansisql.ANSIDialect):
def schemadropper(self, *args, **kwargs):
return MySQLSchemaDropper(*args, **kwargs)
+ def preparer(self):
+ return MySQLIdentifierPreparer()
+
def do_rollback(self, connection):
# some versions of MySQL just dont support rollback() at all....
try:
@@ -428,7 +431,7 @@ class MySQLCompiler(ansisql.ANSICompiler):
class MySQLSchemaGenerator(ansisql.ANSISchemaGenerator):
def get_column_specification(self, column, override_pk=False, first_pk=False):
- colspec = column.name + " " + column.type.engine_impl(self.engine).get_col_spec()
+ colspec = self.preparer.format_column(column) + " " + column.type.engine_impl(self.engine).get_col_spec()
default = self.get_column_default_string(column)
if default is not None:
colspec += " DEFAULT " + default
@@ -452,4 +455,14 @@ class MySQLSchemaDropper(ansisql.ANSISchemaDropper):
self.append("\nDROP INDEX " + index.name + " ON " + index.table.name)
self.execute()
+class MySQLIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
+ def __init__(self):
+ super(MySQLIdentifierPreparer, self).__init__(initial_quote='`')
+ def _escape_identifier(self, value):
+ #TODO: determin MySQL's escaping rules
+ return value
+ def _fold_identifier_case(self, value):
+ #TODO: determin MySQL's case folding rules
+ return value
+
dialect = MySQLDialect
diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py
index 8368c8931..4efe0e162 100644
--- a/lib/sqlalchemy/databases/postgres.py
+++ b/lib/sqlalchemy/databases/postgres.py
@@ -245,6 +245,8 @@ class PGDialect(ansisql.ANSIDialect):
return PGSchemaDropper(*args, **kwargs)
def defaultrunner(self, engine, proxy):
return PGDefaultRunner(engine, proxy)
+ def preparer(self):
+ return PGIdentifierPreparer()
def get_default_schema_name(self, connection):
if not hasattr(self, '_default_schema_name'):
@@ -331,7 +333,7 @@ class PGCompiler(ansisql.ANSICompiler):
class PGSchemaGenerator(ansisql.ANSISchemaGenerator):
def get_column_specification(self, column, **kwargs):
- colspec = column.name
+ colspec = self.preparer.format_column(column)
if column.primary_key and not column.foreign_key and isinstance(column.type, sqltypes.Integer) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
colspec += " SERIAL"
else:
@@ -382,4 +384,8 @@ class PGDefaultRunner(ansisql.ANSIDefaultRunner):
else:
return None
+class PGIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
+ def _fold_identifier_case(self, value):
+ return value.lower()
+
dialect = PGDialect
diff --git a/lib/sqlalchemy/databases/sqlite.py b/lib/sqlalchemy/databases/sqlite.py
index 192b90561..1727dce94 100644
--- a/lib/sqlalchemy/databases/sqlite.py
+++ b/lib/sqlalchemy/databases/sqlite.py
@@ -139,6 +139,8 @@ class SQLiteDialect(ansisql.ANSIDialect):
return SQLiteCompiler(self, statement, bindparams, **kwargs)
def schemagenerator(self, *args, **kwargs):
return SQLiteSchemaGenerator(*args, **kwargs)
+ def preparer(self):
+ return SQLiteIdentifierPreparer()
def create_connect_args(self, url):
filename = url.database or ':memory:'
return ([filename], url.query)
@@ -148,7 +150,7 @@ class SQLiteDialect(ansisql.ANSIDialect):
return SQLiteExecutionContext(self)
def last_inserted_ids(self):
return self.context.last_inserted_ids
-
+
def oid_column_name(self):
return "oid"
@@ -276,7 +278,7 @@ class SQLiteCompiler(ansisql.ANSICompiler):
class SQLiteSchemaGenerator(ansisql.ANSISchemaGenerator):
def get_column_specification(self, column, **kwargs):
- colspec = column.name + " " + column.type.engine_impl(self.engine).get_col_spec()
+ colspec = self.preparer.format_column(column) + " " + column.type.engine_impl(self.engine).get_col_spec()
default = self.get_column_default_string(column)
if default is not None:
colspec += " DEFAULT " + default
@@ -294,6 +296,10 @@ class SQLiteSchemaGenerator(ansisql.ANSISchemaGenerator):
# self.append("\tUNIQUE (%s)" % string.join([c.name for c in constraint],', '))
# else:
# super(SQLiteSchemaGenerator, self).visit_primary_key_constraint(constraint)
-
+
+class SQLiteIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
+ def __init__(self):
+ super(SQLiteIdentifierPreparer, self).__init__(omit_schema=True)
+
dialect = SQLiteDialect
poolclass = pool.SingletonThreadPool
diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py
index 8577b24e1..bba73ef88 100644
--- a/lib/sqlalchemy/schema.py
+++ b/lib/sqlalchemy/schema.py
@@ -142,6 +142,12 @@ class Table(SchemaItem, sql.TableClause):
owner=None : optional owning user of this table. useful for databases such as Oracle to aid in table
reflection.
+
+ quote=False : indicates that the Table identifier must be properly escaped and quoted before being sent
+ to the database.
+
+ quote_schema=False : indicates that the Namespace identifier must be properly escaped and quoted before being sent
+ to the database.
"""
super(Table, self).__init__(name)
self._metadata = metadata
@@ -155,6 +161,8 @@ class Table(SchemaItem, sql.TableClause):
else:
self.fullname = self.name
self.owner = kwargs.pop('owner', None)
+ self.quote = kwargs.pop('quote', False)
+ self.quote_schema = kwargs.pop('quote_schema', False)
self.kwargs = kwargs
def _set_primary_key(self, pk):
@@ -322,6 +330,8 @@ class Column(SchemaItem, sql.ColumnClause):
specify the same index name will all be included in the index, in the
order of their creation.
+ quote=False : indicates that the Column identifier must be properly escaped and quoted before being sent
+ to the database.
"""
name = str(name) # in case of incoming unicode
super(Column, self).__init__(name, None, type)
@@ -333,6 +343,7 @@ class Column(SchemaItem, sql.ColumnClause):
self.default = kwargs.pop('default', None)
self.index = kwargs.pop('index', None)
self.unique = kwargs.pop('unique', None)
+ self.quote = kwargs.pop('quote', False)
self.onupdate = kwargs.pop('onupdate', None)
if self.index is not None and self.unique is not None:
raise exceptions.ArgumentError("Column may not define both index and unique")
diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py
index 6742eac0e..18591c24c 100644
--- a/lib/sqlalchemy/sql.py
+++ b/lib/sqlalchemy/sql.py
@@ -1185,7 +1185,8 @@ class Label(ColumnElement):
return self.obj._get_from_objects()
def _make_proxy(self, selectable, name = None):
return self.obj._make_proxy(selectable, name=self.name)
-
+
+legal_characters = util.Set(string.ascii_letters + string.digits + '_')
class ColumnClause(ColumnElement):
"""represents a textual column clause in a SQL statement. May or may not
be bound to an underlying Selectable."""
@@ -1203,6 +1204,7 @@ class ColumnClause(ColumnElement):
self.__label = self.__label[0:24] + "_" + hex(random.randint(0, 65535))[2:]
else:
self.__label = self.name
+ self.__label = "".join([x for x in self.__label if x in legal_characters])
return self.__label
_label = property(_get_label)
def accept_visitor(self, visitor):