From 2a079cdc76b3a0f5b4f37299d280d328586e2f7e Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sun, 11 Aug 2019 15:24:13 -0400 Subject: Rewrite pool reset_on_return parsing using a util function Choosing a util.symbol() based on a user parameter is about to have another use case added as part of #4623, so add a generalized solution ahead of it. Change-Id: I420631f81af2ffc655995b9cce9ff2ac618c16d7 --- lib/sqlalchemy/pool/base.py | 20 ++++++++++---------- lib/sqlalchemy/util/langhelpers.py | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 10 deletions(-) (limited to 'lib/sqlalchemy') diff --git a/lib/sqlalchemy/pool/base.py b/lib/sqlalchemy/pool/base.py index 761127e83..2325e7faa 100644 --- a/lib/sqlalchemy/pool/base.py +++ b/lib/sqlalchemy/pool/base.py @@ -191,16 +191,16 @@ class Pool(log.Identified): self._recycle = recycle self._invalidate_time = 0 self._pre_ping = pre_ping - if reset_on_return in ("rollback", True, reset_rollback): - self._reset_on_return = reset_rollback - elif reset_on_return in ("none", None, False, reset_none): - self._reset_on_return = reset_none - elif reset_on_return in ("commit", reset_commit): - self._reset_on_return = reset_commit - else: - raise exc.ArgumentError( - "Invalid value for 'reset_on_return': %r" % reset_on_return - ) + self._reset_on_return = util.symbol.parse_user_argument( + reset_on_return, + { + reset_rollback: ["rollback", True], + reset_none: ["none", None, False], + reset_commit: ["commit"], + }, + "reset_on_return", + resolve_symbol_names=False, + ) self.echo = echo diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py index 0e0e3f4df..12fc5c0e8 100644 --- a/lib/sqlalchemy/util/langhelpers.py +++ b/lib/sqlalchemy/util/langhelpers.py @@ -1368,6 +1368,41 @@ class symbol(object): finally: symbol._lock.release() + @classmethod + def parse_user_argument( + cls, arg, choices, name, resolve_symbol_names=False + ): + """Given a user parameter, parse the parameter into a chosen symbol. + + The user argument can be a string name that matches the name of a + symbol, or the symbol object itself, or any number of alternate choices + such as True/False/ None etc. + + :param arg: the user argument. + :param choices: dictionary of symbol object to list of possible + entries. + :param name: name of the argument. Used in an :class:`.ArgumentError` + that is raised if the parameter doesn't match any available argument. + :param resolve_symbol_names: include the name of each symbol as a valid + entry. + + """ + # note using hash lookup is tricky here because symbol's `__hash__` + # is its int value which we don't want included in the lookup + # explicitly, so we iterate and compare each. + for sym, choice in choices.items(): + if arg is sym: + return sym + elif resolve_symbol_names and arg == sym.name: + return sym + elif arg in choice: + return sym + + if arg is None: + return None + + raise exc.ArgumentError("Invalid value for '%s': %r" % (name, arg)) + _creation_order = 1 -- cgit v1.2.1