diff options
Diffstat (limited to 'lib/sqlalchemy/ext/baked.py')
| -rw-r--r-- | lib/sqlalchemy/ext/baked.py | 135 | 
1 files changed, 87 insertions, 48 deletions
| diff --git a/lib/sqlalchemy/ext/baked.py b/lib/sqlalchemy/ext/baked.py index 516879142..f55231a09 100644 --- a/lib/sqlalchemy/ext/baked.py +++ b/lib/sqlalchemy/ext/baked.py @@ -38,7 +38,8 @@ class Bakery(object):      """ -    __slots__ = 'cls', 'cache' + +    __slots__ = "cls", "cache"      def __init__(self, cls_, cache):          self.cls = cls_ @@ -51,7 +52,7 @@ class Bakery(object):  class BakedQuery(object):      """A builder object for :class:`.query.Query` objects.""" -    __slots__ = 'steps', '_bakery', '_cache_key', '_spoiled' +    __slots__ = "steps", "_bakery", "_cache_key", "_spoiled"      def __init__(self, bakery, initial_fn, args=()):          self._cache_key = () @@ -148,7 +149,7 @@ class BakedQuery(object):          """          if not full and not self._spoiled:              _spoil_point = self._clone() -            _spoil_point._cache_key += ('_query_only', ) +            _spoil_point._cache_key += ("_query_only",)              self.steps = [_spoil_point._retrieve_baked_query]          self._spoiled = True          return self @@ -164,7 +165,7 @@ class BakedQuery(object):          session will want to use.          """ -        return self._cache_key + (session._query_cls, ) +        return self._cache_key + (session._query_cls,)      def _with_lazyload_options(self, options, effective_path, cache_path=None):          """Cloning version of _add_lazyload_options. @@ -201,16 +202,20 @@ class BakedQuery(object):                      key += cache_key          self.add_criteria( -            lambda q: q._with_current_path(effective_path). -            _conditional_options(*options), -            cache_path.path, key +            lambda q: q._with_current_path( +                effective_path +            )._conditional_options(*options), +            cache_path.path, +            key,          )      def _retrieve_baked_query(self, session):          query = self._bakery.get(self._effective_key(session), None)          if query is None:              query = self._as_query(session) -            self._bakery[self._effective_key(session)] = query.with_session(None) +            self._bakery[self._effective_key(session)] = query.with_session( +                None +            )          return query.with_session(session)      def _bake(self, session): @@ -227,8 +232,12 @@ class BakedQuery(object):          # so delete some compilation-use-only attributes that can take up          # space          for attr in ( -                '_correlate', '_from_obj', '_mapper_adapter_map', -                '_joinpath', '_joinpoint'): +            "_correlate", +            "_from_obj", +            "_mapper_adapter_map", +            "_joinpath", +            "_joinpoint", +        ):              query.__dict__.pop(attr, None)          self._bakery[self._effective_key(session)] = context          return context @@ -276,11 +285,13 @@ class BakedQuery(object):              session = query_or_session.session              if session is None:                  raise sa_exc.ArgumentError( -                    "Given Query needs to be associated with a Session") +                    "Given Query needs to be associated with a Session" +                )          else:              raise TypeError( -                "Query or Session object expected, got %r." % -                type(query_or_session)) +                "Query or Session object expected, got %r." +                % type(query_or_session) +            )          return self._as_query(session)      def _as_query(self, session): @@ -299,10 +310,10 @@ class BakedQuery(object):          a "baked" query so that we save on performance too.          """ -        context.attributes['baked_queries'] = baked_queries = [] +        context.attributes["baked_queries"] = baked_queries = []          for k, v in list(context.attributes.items()):              if isinstance(v, Query): -                if 'subquery' in k: +                if "subquery" in k:                      bk = BakedQuery(self._bakery, lambda *args: v)                      bk._cache_key = self._cache_key + k                      bk._bake(session) @@ -310,15 +321,17 @@ class BakedQuery(object):                  del context.attributes[k]      def _unbake_subquery_loaders( -            self, session, context, params, post_criteria): +        self, session, context, params, post_criteria +    ):          """Retrieve subquery eager loaders stored by _bake_subquery_loaders          and turn them back into Result objects that will iterate just          like a Query object.          """          for k, cache_key, query in context.attributes["baked_queries"]: -            bk = BakedQuery(self._bakery, -                            lambda sess, q=query: q.with_session(sess)) +            bk = BakedQuery( +                self._bakery, lambda sess, q=query: q.with_session(sess) +            )              bk._cache_key = cache_key              q = bk.for_session(session)              for fn in post_criteria: @@ -334,7 +347,8 @@ class Result(object):      against a target :class:`.Session`, and is then invoked for results.      """ -    __slots__ = 'bq', 'session', '_params', '_post_criteria' + +    __slots__ = "bq", "session", "_params", "_post_criteria"      def __init__(self, bq, session):          self.bq = bq @@ -350,7 +364,8 @@ class Result(object):          elif len(args) > 0:              raise sa_exc.ArgumentError(                  "params() takes zero or one positional argument, " -                "which is a dictionary.") +                "which is a dictionary." +            )          self._params.update(kw)          return self @@ -403,7 +418,8 @@ class Result(object):          context.attributes = context.attributes.copy()          bq._unbake_subquery_loaders( -            self.session, context, self._params, self._post_criteria) +            self.session, context, self._params, self._post_criteria +        )          context.statement.use_labels = True          if context.autoflush and not context.populate_existing: @@ -426,7 +442,7 @@ class Result(object):          """ -        col = func.count(literal_column('*')) +        col = func.count(literal_column("*"))          bq = self.bq.with_criteria(lambda q: q.from_self(col))          return bq.for_session(self.session).params(self._params).scalar() @@ -456,8 +472,10 @@ class Result(object):          """          bq = self.bq.with_criteria(lambda q: q.slice(0, 1))          ret = list( -            bq.for_session(self.session).params(self._params). -            _using_post_criteria(self._post_criteria)) +            bq.for_session(self.session) +            .params(self._params) +            ._using_post_criteria(self._post_criteria) +        )          if len(ret) > 0:              return ret[0]          else: @@ -473,7 +491,8 @@ class Result(object):              ret = self.one_or_none()          except orm_exc.MultipleResultsFound:              raise orm_exc.MultipleResultsFound( -                "Multiple rows were found for one()") +                "Multiple rows were found for one()" +            )          else:              if ret is None:                  raise orm_exc.NoResultFound("No row was found for one()") @@ -497,7 +516,8 @@ class Result(object):              return None          else:              raise orm_exc.MultipleResultsFound( -                "Multiple rows were found for one_or_none()") +                "Multiple rows were found for one_or_none()" +            )      def all(self):          """Return all rows. @@ -533,13 +553,18 @@ class Result(object):              # None present in ident - turn those comparisons              # into "IS NULL"              if None in primary_key_identity: -                nones = set([ -                    _get_params[col].key for col, value in -                    zip(mapper.primary_key, primary_key_identity) -                    if value is None -                ]) +                nones = set( +                    [ +                        _get_params[col].key +                        for col, value in zip( +                            mapper.primary_key, primary_key_identity +                        ) +                        if value is None +                    ] +                )                  _lcl_get_clause = sql_util.adapt_criterion_to_null( -                    _lcl_get_clause, nones) +                    _lcl_get_clause, nones +                )              _lcl_get_clause = q._adapt_clause(_lcl_get_clause, True, False)              q._criterion = _lcl_get_clause @@ -556,16 +581,20 @@ class Result(object):          # key so that if a race causes multiple calls to _get_clause,          # we've cached on ours          bq = bq._clone() -        bq._cache_key += (_get_clause, ) +        bq._cache_key += (_get_clause,)          bq = bq.with_criteria( -            setup, tuple(elem is None for elem in primary_key_identity)) +            setup, tuple(elem is None for elem in primary_key_identity) +        ) -        params = dict([ -            (_get_params[primary_key].key, id_val) -            for id_val, primary_key -            in zip(primary_key_identity, mapper.primary_key) -        ]) +        params = dict( +            [ +                (_get_params[primary_key].key, id_val) +                for id_val, primary_key in zip( +                    primary_key_identity, mapper.primary_key +                ) +            ] +        )          result = list(bq.for_session(self.session).params(**params))          l = len(result) @@ -578,7 +607,8 @@ class Result(object):  @util.deprecated( -    "1.2", "Baked lazy loading is now the default implementation.") +    "1.2", "Baked lazy loading is now the default implementation." +)  def bake_lazy_loaders():      """Enable the use of baked queries for all lazyloaders systemwide. @@ -590,7 +620,8 @@ def bake_lazy_loaders():  @util.deprecated( -    "1.2", "Baked lazy loading is now the default implementation.") +    "1.2", "Baked lazy loading is now the default implementation." +)  def unbake_lazy_loaders():      """Disable the use of baked queries for all lazyloaders systemwide. @@ -601,7 +632,8 @@ def unbake_lazy_loaders():      """      raise NotImplementedError( -        "Baked lazy loading is now the default implementation") +        "Baked lazy loading is now the default implementation" +    )  @strategy_options.loader_option() @@ -615,20 +647,27 @@ def baked_lazyload(loadopt, attr):  @baked_lazyload._add_unbound_fn  @util.deprecated( -    "1.2", "Baked lazy loading is now the default " -    "implementation for lazy loading.") +    "1.2", +    "Baked lazy loading is now the default " +    "implementation for lazy loading.", +)  def baked_lazyload(*keys):      return strategy_options._UnboundLoad._from_keys( -        strategy_options._UnboundLoad.baked_lazyload, keys, False, {}) +        strategy_options._UnboundLoad.baked_lazyload, keys, False, {} +    )  @baked_lazyload._add_unbound_all_fn  @util.deprecated( -    "1.2", "Baked lazy loading is now the default " -    "implementation for lazy loading.") +    "1.2", +    "Baked lazy loading is now the default " +    "implementation for lazy loading.", +)  def baked_lazyload_all(*keys):      return strategy_options._UnboundLoad._from_keys( -        strategy_options._UnboundLoad.baked_lazyload, keys, True, {}) +        strategy_options._UnboundLoad.baked_lazyload, keys, True, {} +    ) +  baked_lazyload = baked_lazyload._unbound_fn  baked_lazyload_all = baked_lazyload_all._unbound_all_fn | 
