diff options
Diffstat (limited to 'lib/sqlalchemy/engine/result.py')
| -rw-r--r-- | lib/sqlalchemy/engine/result.py | 55 |
1 files changed, 42 insertions, 13 deletions
diff --git a/lib/sqlalchemy/engine/result.py b/lib/sqlalchemy/engine/result.py index d5b8057ef..cc6d26c88 100644 --- a/lib/sqlalchemy/engine/result.py +++ b/lib/sqlalchemy/engine/result.py @@ -22,6 +22,7 @@ from typing import Generic from typing import Iterable from typing import Iterator from typing import List +from typing import Mapping from typing import NoReturn from typing import Optional from typing import overload @@ -59,7 +60,7 @@ _KeyIndexType = Union[str, "Column[Any]", int] # is overridden in cursor using _CursorKeyMapRecType _KeyMapRecType = Any -_KeyMapType = Dict[_KeyType, _KeyMapRecType] +_KeyMapType = Mapping[_KeyType, _KeyMapRecType] _RowData = Union[Row, RowMapping, Any] @@ -99,6 +100,7 @@ class ResultMetaData: _keymap: _KeyMapType _keys: Sequence[str] _processors: Optional[_ProcessorsType] + _key_to_index: Mapping[_KeyType, int] @property def keys(self) -> RMKeyView: @@ -112,24 +114,27 @@ class ResultMetaData: @overload def _key_fallback( - self, key: Any, err: Exception, raiseerr: Literal[True] = ... + self, key: Any, err: Optional[Exception], raiseerr: Literal[True] = ... ) -> NoReturn: ... @overload def _key_fallback( - self, key: Any, err: Exception, raiseerr: Literal[False] = ... + self, + key: Any, + err: Optional[Exception], + raiseerr: Literal[False] = ..., ) -> None: ... @overload def _key_fallback( - self, key: Any, err: Exception, raiseerr: bool = ... + self, key: Any, err: Optional[Exception], raiseerr: bool = ... ) -> Optional[NoReturn]: ... def _key_fallback( - self, key: Any, err: Exception, raiseerr: bool = True + self, key: Any, err: Optional[Exception], raiseerr: bool = True ) -> Optional[NoReturn]: assert raiseerr raise KeyError(key) from err @@ -177,6 +182,29 @@ class ResultMetaData: indexes = self._indexes_for_keys(keys) return tuplegetter(*indexes) + def _make_key_to_index( + self, keymap: Mapping[_KeyType, Sequence[Any]], index: int + ) -> Mapping[_KeyType, int]: + return { + key: rec[index] + for key, rec in keymap.items() + if rec[index] is not None + } + + def _key_not_found(self, key: Any, attr_error: bool) -> NoReturn: + if key in self._keymap: + # the index must be none in this case + self._raise_for_ambiguous_column_name(self._keymap[key]) + else: + # unknown key + if attr_error: + try: + self._key_fallback(key, None) + except KeyError as ke: + raise AttributeError(ke.args[0]) from ke + else: + self._key_fallback(key, None) + class RMKeyView(typing.KeysView[Any]): __slots__ = ("_parent", "_keys") @@ -222,6 +250,7 @@ class SimpleResultMetaData(ResultMetaData): "_tuplefilter", "_translated_indexes", "_unique_filters", + "_key_to_index", ) _keys: Sequence[str] @@ -257,6 +286,8 @@ class SimpleResultMetaData(ResultMetaData): self._processors = _processors + self._key_to_index = self._make_key_to_index(self._keymap, 0) + def _has_key(self, key: object) -> bool: return key in self._keymap @@ -359,7 +390,7 @@ def result_tuple( ) -> Callable[[Iterable[Any]], Row[Any]]: parent = SimpleResultMetaData(fields, extra) return functools.partial( - Row, parent, parent._processors, parent._keymap, Row._default_key_style + Row, parent, parent._processors, parent._key_to_index ) @@ -424,21 +455,19 @@ class ResultInternal(InPlaceGenerative, Generic[_R]): def process_row( # type: ignore metadata: ResultMetaData, processors: _ProcessorsType, - keymap: _KeyMapType, - key_style: Any, + key_to_index: Mapping[_KeyType, int], scalar_obj: Any, ) -> Row[Any]: return _proc( - metadata, processors, keymap, key_style, (scalar_obj,) + metadata, processors, key_to_index, (scalar_obj,) ) else: process_row = Row # type: ignore - key_style = Row._default_key_style metadata = self._metadata - keymap = metadata._keymap + key_to_index = metadata._key_to_index processors = metadata._processors tf = metadata._tuplefilter @@ -447,7 +476,7 @@ class ResultInternal(InPlaceGenerative, Generic[_R]): processors = tf(processors) _make_row_orig: Callable[..., _R] = functools.partial( # type: ignore # noqa E501 - process_row, metadata, processors, keymap, key_style + process_row, metadata, processors, key_to_index ) fixed_tf = tf @@ -457,7 +486,7 @@ class ResultInternal(InPlaceGenerative, Generic[_R]): else: make_row = functools.partial( # type: ignore - process_row, metadata, processors, keymap, key_style + process_row, metadata, processors, key_to_index ) fns: Tuple[Any, ...] = () |
