diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2015-04-18 14:03:45 -0400 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2015-04-18 14:03:45 -0400 |
| commit | d2037c31a371b65bad1ee72800af8bb2f5ad9f87 (patch) | |
| tree | b94160fb08922fa6d343388692d8487c3cfeffed | |
| parent | bd61e7a3287079cf742f4df698bfe3628c090522 (diff) | |
| download | sqlalchemy-d2037c31a371b65bad1ee72800af8bb2f5ad9f87.tar.gz | |
- Added new event :class:`.DialectEvents.do_connect`, which allows
interception / replacement of when the :meth:`.Dialect.connect`
hook is called to create a DBAPI connection. Also added
dialect plugin hooks :meth:`.Dialect.get_dialect_cls` and
:meth:`.Dialect.engine_created` which allow external plugins to
add events to existing dialects using entry points.
fixes #3355
| -rw-r--r-- | doc/build/changelog/changelog_10.rst | 11 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/interfaces.py | 35 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/strategies.py | 13 | ||||
| -rw-r--r-- | lib/sqlalchemy/events.py | 17 | ||||
| -rw-r--r-- | test/engine/test_execute.py | 39 | ||||
| -rw-r--r-- | test/engine/test_parseconnect.py | 29 |
6 files changed, 142 insertions, 2 deletions
diff --git a/doc/build/changelog/changelog_10.rst b/doc/build/changelog/changelog_10.rst index cff3f1b5c..bca26d307 100644 --- a/doc/build/changelog/changelog_10.rst +++ b/doc/build/changelog/changelog_10.rst @@ -19,6 +19,17 @@ :version: 1.0.1 .. change:: + :tags: feature, engine + :tickets: 3355 + + Added new event :class:`.DialectEvents.do_connect`, which allows + interception / replacement of when the :meth:`.Dialect.connect` + hook is called to create a DBAPI connection. Also added + dialect plugin hooks :meth:`.Dialect.get_dialect_cls` and + :meth:`.Dialect.engine_created` which allow external plugins to + add events to existing dialects using entry points. + + .. change:: :tags: bug, orm :tickets: 3368 diff --git a/lib/sqlalchemy/engine/interfaces.py b/lib/sqlalchemy/engine/interfaces.py index da8fa81eb..ecb9a8489 100644 --- a/lib/sqlalchemy/engine/interfaces.py +++ b/lib/sqlalchemy/engine/interfaces.py @@ -733,6 +733,41 @@ class Dialect(object): raise NotImplementedError() + @classmethod + def get_dialect_cls(cls, url): + """Given a URL, return the :class:`.Dialect` that will be used. + + This is a hook that allows an external plugin to provide functionality + around an existing dialect, by allowing the plugin to be loaded + from the url based on an entrypoint, and then the plugin returns + the actual dialect to be used. + + By default this just returns the cls. + + .. versionadded:: 1.0.1 + + """ + return cls + + @classmethod + def engine_created(cls, engine): + """A convenience hook called before returning the final :class:`.Engine`. + + If the dialect returned a different class from the + :meth:`.get_dialect_cls` + method, then the hook is called on both classes, first on + the dialect class returned by the :meth:`.get_dialect_cls` method and + then on the class on which the method was called. + + The hook should be used by dialects and/or wrappers to apply special + events to the engine or its components. In particular, it allows + a dialect-wrapping class to apply dialect-level events. + + .. versionadded:: 1.0.1 + + """ + pass + class ExecutionContext(object): """A messenger object for a Dialect that corresponds to a single diff --git a/lib/sqlalchemy/engine/strategies.py b/lib/sqlalchemy/engine/strategies.py index 1fd105d67..66de782b8 100644 --- a/lib/sqlalchemy/engine/strategies.py +++ b/lib/sqlalchemy/engine/strategies.py @@ -48,7 +48,8 @@ class DefaultEngineStrategy(EngineStrategy): # create url.URL object u = url.make_url(name_or_url) - dialect_cls = u.get_dialect() + entrypoint = u.get_dialect() + dialect_cls = entrypoint.get_dialect_cls(u) if kwargs.pop('_coerce_config', False): def pop_kwarg(key, default=None): @@ -81,11 +82,17 @@ class DefaultEngineStrategy(EngineStrategy): # assemble connection arguments (cargs, cparams) = dialect.create_connect_args(u) cparams.update(pop_kwarg('connect_args', {})) + cargs = list(cargs) # allow mutability # look for existing pool or create pool = pop_kwarg('pool', None) if pool is None: def connect(): + if dialect._has_events: + for fn in dialect.dispatch.do_connect: + connection = fn(dialect, cargs, cparams) + if connection is not None: + return connection return dialect.connect(*cargs, **cparams) creator = pop_kwarg('creator', connect) @@ -157,6 +164,10 @@ class DefaultEngineStrategy(EngineStrategy): dialect.initialize(c) event.listen(pool, 'first_connect', first_connect, once=True) + dialect_cls.engine_created(engine) + if entrypoint is not dialect_cls: + entrypoint.engine_created(engine) + return engine diff --git a/lib/sqlalchemy/events.py b/lib/sqlalchemy/events.py index 22e066c88..7f5e89ebb 100644 --- a/lib/sqlalchemy/events.py +++ b/lib/sqlalchemy/events.py @@ -1007,6 +1007,23 @@ class DialectEvents(event.Events): else: return target + def do_connect(self, dialect, cargs, cparams): + """Receive connection arguments before a connection is made. + + Return a DBAPI connection to halt further events from invoking; + the returned connection will be used. + + Alternatively, the event can manipulate the cargs and/or cparams + collections; cargs will always be a Python list that can be mutated + in-place and cparams a Python dictionary. Return None to + allow control to pass to the next event handler and ultimately + to allow the dialect to connect normally, given the updated + arguments. + + .. versionadded:: 1.0.1 + + """ + def do_executemany(self, cursor, statement, parameters, context): """Receive a cursor to have executemany() called. diff --git a/test/engine/test_execute.py b/test/engine/test_execute.py index b0256d325..dbab21cd7 100644 --- a/test/engine/test_execute.py +++ b/test/engine/test_execute.py @@ -2532,3 +2532,42 @@ class DialectEventTest(fixtures.TestBase): def test_cursor_execute_wo_replace(self): self._test_cursor_execute(False) + + def test_connect_replace_params(self): + e = engines.testing_engine(options={"_initialize": False}) + + @event.listens_for(e, "do_connect") + def evt(dialect, cargs, cparams): + cargs[:] = ['foo', 'hoho'] + cparams.clear() + cparams['bar'] = 'bat' + + m1 = Mock() + e.dialect.connect = m1.real_connect + + with e.connect(): + eq_(m1.mock_calls, [call.real_connect('foo', 'hoho', bar='bat')]) + + def test_connect_do_connect(self): + e = engines.testing_engine(options={"_initialize": False}) + + m1 = Mock() + + @event.listens_for(e, "do_connect") + def evt1(dialect, cargs, cparams): + cargs[:] = ['foo', 'hoho'] + cparams.clear() + cparams['bar'] = 'bat' + + @event.listens_for(e, "do_connect") + def evt2(dialect, cargs, cparams): + return m1.our_connect(cargs, cparams) + + with e.connect() as conn: + # called with args + eq_( + m1.mock_calls, + [call.our_connect(['foo', 'hoho'], {'bar': 'bat'})]) + + # returned our mock connection + is_(conn.connection.connection, m1.our_connect()) diff --git a/test/engine/test_parseconnect.py b/test/engine/test_parseconnect.py index e53a99e15..98994e2de 100644 --- a/test/engine/test_parseconnect.py +++ b/test/engine/test_parseconnect.py @@ -5,7 +5,7 @@ from sqlalchemy.engine.default import DefaultDialect import sqlalchemy as tsa from sqlalchemy.testing import fixtures from sqlalchemy import testing -from sqlalchemy.testing.mock import Mock, MagicMock +from sqlalchemy.testing.mock import Mock, MagicMock, call from sqlalchemy import event from sqlalchemy import select @@ -325,6 +325,33 @@ class TestRegNewDBAPI(fixtures.TestBase): e = create_engine("mysql+my_mock_dialect://") assert isinstance(e.dialect, MockDialect) + @testing.requires.sqlite + def test_wrapper_hooks(self): + def get_dialect_cls(url): + url.drivername = "sqlite" + return url.get_dialect() + + global WrapperFactory + WrapperFactory = Mock() + WrapperFactory.get_dialect_cls.side_effect = get_dialect_cls + + from sqlalchemy.dialects import registry + registry.register("wrapperdialect", __name__, "WrapperFactory") + + from sqlalchemy.dialects import sqlite + e = create_engine("wrapperdialect://") + + eq_(e.dialect.name, "sqlite") + assert isinstance(e.dialect, sqlite.dialect) + + eq_( + WrapperFactory.mock_calls, + [ + call.get_dialect_cls(url.make_url("sqlite://")), + call.engine_created(e) + ] + ) + class MockDialect(DefaultDialect): @classmethod |
