diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2019-10-03 17:36:27 -0400 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2019-10-07 23:06:06 -0400 |
| commit | 65aee6cce57fd1cca3a95814feff3ed99a5a51ee (patch) | |
| tree | 0352d74938902a9242dfb97ca5215d9191a2ad16 /lib/sqlalchemy | |
| parent | ebd9788c986c56b8b845fa83609a6eb2c0cef083 (diff) | |
| download | sqlalchemy-65aee6cce57fd1cca3a95814feff3ed99a5a51ee.tar.gz | |
Add result map targeting for custom compiled, text objects
In order for text(), custom compiled objects, etc. to be usable
by Query(), they are all targeted by object key in the result map.
As we no longer want Query to implicitly label these, as well as that
text() has no label feature, support adding entries to the result
map that have no name, key, or type, only the object itself, and
then ensure that the compiler sets up for positional targeting
when this condition is detected.
Allows for more flexible ORM query usage with custom expressions
and text() while having less special logic in query itself.
Fixes: #4887
Change-Id: Ie073da127d292d43cb132a2b31bc90af88bfe2fd
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/engine/result.py | 62 | ||||
| -rw-r--r-- | lib/sqlalchemy/ext/compiler.py | 21 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 72 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/selectable.py | 12 |
4 files changed, 111 insertions, 56 deletions
diff --git a/lib/sqlalchemy/engine/result.py b/lib/sqlalchemy/engine/result.py index af5303658..733bd6f6a 100644 --- a/lib/sqlalchemy/engine/result.py +++ b/lib/sqlalchemy/engine/result.py @@ -321,40 +321,41 @@ class ResultMetaData(object): # dupe records with "None" for index which results in # ambiguous column exception when accessed. if len(by_key) != num_ctx_cols: - seen = set() + # new in 1.4: get the complete set of all possible keys, + # strings, objects, whatever, that are dupes across two + # different records, first. + index_by_key = {} + dupes = set() for metadata_entry in raw: - key = metadata_entry[MD_RENDERED_NAME] - if key in seen: - # this is an "ambiguous" element, replacing - # the full record in the map - key = key.lower() if not self.case_sensitive else key - by_key[key] = (None, (), key) - seen.add(key) - - # copy secondary elements from compiled columns - # into self._keymap, write in the potentially "ambiguous" - # element + for key in (metadata_entry[MD_RENDERED_NAME],) + ( + metadata_entry[MD_OBJECTS] or () + ): + if not self.case_sensitive and isinstance( + key, util.string_types + ): + key = key.lower() + idx = metadata_entry[MD_INDEX] + # if this key has been associated with more than one + # positional index, it's a dupe + if index_by_key.setdefault(key, idx) != idx: + dupes.add(key) + + # then put everything we have into the keymap excluding only + # those keys that are dupes. self._keymap.update( [ - (obj_elem, by_key[metadata_entry[MD_LOOKUP_KEY]]) + (obj_elem, metadata_entry) for metadata_entry in raw if metadata_entry[MD_OBJECTS] for obj_elem in metadata_entry[MD_OBJECTS] + if obj_elem not in dupes ] ) - # if we did a pure positional match, then reset the - # original "expression element" back to the "unambiguous" - # entry. This is a new behavior in 1.1 which impacts - # TextualSelect but also straight compiled SQL constructs. - if not self.matched_on_name: - self._keymap.update( - [ - (metadata_entry[MD_OBJECTS][0], metadata_entry) - for metadata_entry in raw - if metadata_entry[MD_OBJECTS] - ] - ) + # then for the dupe keys, put the "ambiguous column" + # record into by_key. + by_key.update({key: (None, (), key) for key in dupes}) + else: # no dupes - copy secondary elements from compiled # columns into self._keymap @@ -502,16 +503,16 @@ class ResultMetaData(object): ( idx, obj, - colname, - colname, + cursor_colname, + cursor_colname, context.get_result_processor( - mapped_type, colname, coltype + mapped_type, cursor_colname, coltype ), untranslated, ) for ( idx, - colname, + cursor_colname, mapped_type, coltype, obj, @@ -592,7 +593,6 @@ class ResultMetaData(object): else: mapped_type = sqltypes.NULLTYPE obj = None - yield idx, colname, mapped_type, coltype, obj, untranslated def _merge_cols_by_name( @@ -758,7 +758,7 @@ class ResultMetaData(object): if index is None: raise exc.InvalidRequestError( "Ambiguous column name '%s' in " - "result set column descriptions" % obj + "result set column descriptions" % rec[MD_LOOKUP_KEY] ) return operator.methodcaller("_get_by_key_impl", index) diff --git a/lib/sqlalchemy/ext/compiler.py b/lib/sqlalchemy/ext/compiler.py index 572c62b8e..4a5a8ba9c 100644 --- a/lib/sqlalchemy/ext/compiler.py +++ b/lib/sqlalchemy/ext/compiler.py @@ -398,6 +398,7 @@ Example usage:: """ from .. import exc +from ..sql import sqltypes from ..sql import visitors @@ -475,4 +476,22 @@ class _dispatcher(object): "compilation handler." % type(element) ) - return fn(element, compiler, **kw) + # if compilation includes add_to_result_map, collect add_to_result_map + # arguments from the user-defined callable, which are probably none + # because this is not public API. if it wasn't called, then call it + # ourselves. + arm = kw.get("add_to_result_map", None) + if arm: + arm_collection = [] + kw["add_to_result_map"] = lambda *args: arm_collection.append(args) + + expr = fn(element, compiler, **kw) + + if arm: + if not arm_collection: + arm_collection.append( + (None, None, (element,), sqltypes.NULLTYPE) + ) + for tup in arm_collection: + arm(*tup) + return expr diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 320c7b782..453ff56d2 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -871,12 +871,11 @@ class SQLCompiler(Compiled): name = self._truncated_identifier("colident", name) if add_to_result_map is not None: - add_to_result_map( - name, - orig_name, - (column, name, column.key, column._label) + result_map_targets, - column.type, - ) + targets = (column, name, column.key) + result_map_targets + if column._label: + targets += (column._label,) + + add_to_result_map(name, orig_name, targets, column.type) if is_literal: # note we are not currently accommodating for @@ -925,7 +924,7 @@ class SQLCompiler(Compiled): text = text.replace("%", "%%") return text - def visit_textclause(self, textclause, **kw): + def visit_textclause(self, textclause, add_to_result_map=None, **kw): def do_bindparam(m): name = m.group(1) if name in textclause._bindparams: @@ -936,6 +935,12 @@ class SQLCompiler(Compiled): if not self.stack: self.isplaintext = True + if add_to_result_map: + # text() object is present in the columns clause of a + # select(). Add a no-name entry to the result map so that + # row[text()] produces a result + add_to_result_map(None, None, (textclause,), sqltypes.NULLTYPE) + # un-escape any \:params return BIND_PARAMS_ESC.sub( lambda m: m.group(1), @@ -1938,6 +1943,9 @@ class SQLCompiler(Compiled): return " AS " + alias_name_text def _add_to_result_map(self, keyname, name, objects, type_): + if keyname is None: + self._ordered_columns = False + self._textual_ordered_columns = True self._result_columns.append((keyname, name, objects, type_)) def _label_select_column( @@ -1949,6 +1957,7 @@ class SQLCompiler(Compiled): column_clause_args, name=None, within_columns_clause=True, + column_is_repeated=False, need_column_expressions=False, ): """produce labeled columns present in a select().""" @@ -1959,22 +1968,37 @@ class SQLCompiler(Compiled): need_column_expressions or populate_result_map ): col_expr = impl.column_expression(column) + else: + col_expr = column - if populate_result_map: + if populate_result_map: + # pass an "add_to_result_map" callable into the compilation + # of embedded columns. this collects information about the + # column as it will be fetched in the result and is coordinated + # with cursor.description when the query is executed. + add_to_result_map = self._add_to_result_map + + # if the SELECT statement told us this column is a repeat, + # wrap the callable with one that prevents the addition of the + # targets + if column_is_repeated: + _add_to_result_map = add_to_result_map def add_to_result_map(keyname, name, objects, type_): - self._add_to_result_map( + _add_to_result_map(keyname, name, (), type_) + + # if we redefined col_expr for type expressions, wrap the + # callable with one that adds the original column to the targets + elif col_expr is not column: + _add_to_result_map = add_to_result_map + + def add_to_result_map(keyname, name, objects, type_): + _add_to_result_map( keyname, name, (column,) + objects, type_ ) - else: - add_to_result_map = None else: - col_expr = column - if populate_result_map: - add_to_result_map = self._add_to_result_map - else: - add_to_result_map = None + add_to_result_map = None if not within_columns_clause: result_expr = col_expr @@ -2010,7 +2034,7 @@ class SQLCompiler(Compiled): ) and ( not hasattr(column, "name") - or isinstance(column, functions.Function) + or isinstance(column, functions.FunctionElement) ) ): result_expr = _CompileLabel(col_expr, column.anon_label) @@ -2138,9 +2162,10 @@ class SQLCompiler(Compiled): asfrom, column_clause_args, name=name, + column_is_repeated=repeated, need_column_expressions=need_column_expressions, ) - for name, column in select._columns_plus_names + for name, column, repeated in select._columns_plus_names ] if c is not None ] @@ -2151,10 +2176,17 @@ class SQLCompiler(Compiled): translate = dict( zip( - [name for (key, name) in select._columns_plus_names], [ name - for (key, name) in select_wraps_for._columns_plus_names + for (key, name, repeated) in select._columns_plus_names + ], + [ + name + for ( + key, + name, + repeated, + ) in select_wraps_for._columns_plus_names ], ) ) diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index ddbcdf91d..6282cf2ee 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -4191,8 +4191,9 @@ class Select( def name_for_col(c): if c._label is None or not c._render_label_in_columns_clause: - return (None, c) + return (None, c, False) + repeated = False name = c._label if name in names: @@ -4218,19 +4219,22 @@ class Select( # subsequent occurrences of the column so that the # original stays non-ambiguous name = c._dedupe_label_anon_label + repeated = True else: names[name] = c elif anon_for_dupe_key: # same column under the same name. apply the "dedupe" # label so that the original stays non-ambiguous name = c._dedupe_label_anon_label + repeated = True else: names[name] = c - return name, c + return name, c, repeated return [name_for_col(c) for c in cols] else: - return [(None, c) for c in cols] + # repeated name logic only for use labels at the moment + return [(None, c, False) for c in cols] @_memoized_property def _columns_plus_names(self): @@ -4245,7 +4249,7 @@ class Select( keys_seen = set() prox = [] - for name, c in self._generate_columns_plus_names(False): + for name, c, repeated in self._generate_columns_plus_names(False): if not hasattr(c, "_make_proxy"): continue if name is None: |
