summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2015-04-18 14:03:45 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2015-04-18 14:03:45 -0400
commitd2037c31a371b65bad1ee72800af8bb2f5ad9f87 (patch)
treeb94160fb08922fa6d343388692d8487c3cfeffed
parentbd61e7a3287079cf742f4df698bfe3628c090522 (diff)
downloadsqlalchemy-d2037c31a371b65bad1ee72800af8bb2f5ad9f87.tar.gz
- Added new event :class:`.DialectEvents.do_connect`, which allows
interception / replacement of when the :meth:`.Dialect.connect` hook is called to create a DBAPI connection. Also added dialect plugin hooks :meth:`.Dialect.get_dialect_cls` and :meth:`.Dialect.engine_created` which allow external plugins to add events to existing dialects using entry points. fixes #3355
-rw-r--r--doc/build/changelog/changelog_10.rst11
-rw-r--r--lib/sqlalchemy/engine/interfaces.py35
-rw-r--r--lib/sqlalchemy/engine/strategies.py13
-rw-r--r--lib/sqlalchemy/events.py17
-rw-r--r--test/engine/test_execute.py39
-rw-r--r--test/engine/test_parseconnect.py29
6 files changed, 142 insertions, 2 deletions
diff --git a/doc/build/changelog/changelog_10.rst b/doc/build/changelog/changelog_10.rst
index cff3f1b5c..bca26d307 100644
--- a/doc/build/changelog/changelog_10.rst
+++ b/doc/build/changelog/changelog_10.rst
@@ -19,6 +19,17 @@
:version: 1.0.1
.. change::
+ :tags: feature, engine
+ :tickets: 3355
+
+ Added new event :class:`.DialectEvents.do_connect`, which allows
+ interception / replacement of when the :meth:`.Dialect.connect`
+ hook is called to create a DBAPI connection. Also added
+ dialect plugin hooks :meth:`.Dialect.get_dialect_cls` and
+ :meth:`.Dialect.engine_created` which allow external plugins to
+ add events to existing dialects using entry points.
+
+ .. change::
:tags: bug, orm
:tickets: 3368
diff --git a/lib/sqlalchemy/engine/interfaces.py b/lib/sqlalchemy/engine/interfaces.py
index da8fa81eb..ecb9a8489 100644
--- a/lib/sqlalchemy/engine/interfaces.py
+++ b/lib/sqlalchemy/engine/interfaces.py
@@ -733,6 +733,41 @@ class Dialect(object):
raise NotImplementedError()
+ @classmethod
+ def get_dialect_cls(cls, url):
+ """Given a URL, return the :class:`.Dialect` that will be used.
+
+ This is a hook that allows an external plugin to provide functionality
+ around an existing dialect, by allowing the plugin to be loaded
+ from the url based on an entrypoint, and then the plugin returns
+ the actual dialect to be used.
+
+ By default this just returns the cls.
+
+ .. versionadded:: 1.0.1
+
+ """
+ return cls
+
+ @classmethod
+ def engine_created(cls, engine):
+ """A convenience hook called before returning the final :class:`.Engine`.
+
+ If the dialect returned a different class from the
+ :meth:`.get_dialect_cls`
+ method, then the hook is called on both classes, first on
+ the dialect class returned by the :meth:`.get_dialect_cls` method and
+ then on the class on which the method was called.
+
+ The hook should be used by dialects and/or wrappers to apply special
+ events to the engine or its components. In particular, it allows
+ a dialect-wrapping class to apply dialect-level events.
+
+ .. versionadded:: 1.0.1
+
+ """
+ pass
+
class ExecutionContext(object):
"""A messenger object for a Dialect that corresponds to a single
diff --git a/lib/sqlalchemy/engine/strategies.py b/lib/sqlalchemy/engine/strategies.py
index 1fd105d67..66de782b8 100644
--- a/lib/sqlalchemy/engine/strategies.py
+++ b/lib/sqlalchemy/engine/strategies.py
@@ -48,7 +48,8 @@ class DefaultEngineStrategy(EngineStrategy):
# create url.URL object
u = url.make_url(name_or_url)
- dialect_cls = u.get_dialect()
+ entrypoint = u.get_dialect()
+ dialect_cls = entrypoint.get_dialect_cls(u)
if kwargs.pop('_coerce_config', False):
def pop_kwarg(key, default=None):
@@ -81,11 +82,17 @@ class DefaultEngineStrategy(EngineStrategy):
# assemble connection arguments
(cargs, cparams) = dialect.create_connect_args(u)
cparams.update(pop_kwarg('connect_args', {}))
+ cargs = list(cargs) # allow mutability
# look for existing pool or create
pool = pop_kwarg('pool', None)
if pool is None:
def connect():
+ if dialect._has_events:
+ for fn in dialect.dispatch.do_connect:
+ connection = fn(dialect, cargs, cparams)
+ if connection is not None:
+ return connection
return dialect.connect(*cargs, **cparams)
creator = pop_kwarg('creator', connect)
@@ -157,6 +164,10 @@ class DefaultEngineStrategy(EngineStrategy):
dialect.initialize(c)
event.listen(pool, 'first_connect', first_connect, once=True)
+ dialect_cls.engine_created(engine)
+ if entrypoint is not dialect_cls:
+ entrypoint.engine_created(engine)
+
return engine
diff --git a/lib/sqlalchemy/events.py b/lib/sqlalchemy/events.py
index 22e066c88..7f5e89ebb 100644
--- a/lib/sqlalchemy/events.py
+++ b/lib/sqlalchemy/events.py
@@ -1007,6 +1007,23 @@ class DialectEvents(event.Events):
else:
return target
+ def do_connect(self, dialect, cargs, cparams):
+ """Receive connection arguments before a connection is made.
+
+ Return a DBAPI connection to halt further events from invoking;
+ the returned connection will be used.
+
+ Alternatively, the event can manipulate the cargs and/or cparams
+ collections; cargs will always be a Python list that can be mutated
+ in-place and cparams a Python dictionary. Return None to
+ allow control to pass to the next event handler and ultimately
+ to allow the dialect to connect normally, given the updated
+ arguments.
+
+ .. versionadded:: 1.0.1
+
+ """
+
def do_executemany(self, cursor, statement, parameters, context):
"""Receive a cursor to have executemany() called.
diff --git a/test/engine/test_execute.py b/test/engine/test_execute.py
index b0256d325..dbab21cd7 100644
--- a/test/engine/test_execute.py
+++ b/test/engine/test_execute.py
@@ -2532,3 +2532,42 @@ class DialectEventTest(fixtures.TestBase):
def test_cursor_execute_wo_replace(self):
self._test_cursor_execute(False)
+
+ def test_connect_replace_params(self):
+ e = engines.testing_engine(options={"_initialize": False})
+
+ @event.listens_for(e, "do_connect")
+ def evt(dialect, cargs, cparams):
+ cargs[:] = ['foo', 'hoho']
+ cparams.clear()
+ cparams['bar'] = 'bat'
+
+ m1 = Mock()
+ e.dialect.connect = m1.real_connect
+
+ with e.connect():
+ eq_(m1.mock_calls, [call.real_connect('foo', 'hoho', bar='bat')])
+
+ def test_connect_do_connect(self):
+ e = engines.testing_engine(options={"_initialize": False})
+
+ m1 = Mock()
+
+ @event.listens_for(e, "do_connect")
+ def evt1(dialect, cargs, cparams):
+ cargs[:] = ['foo', 'hoho']
+ cparams.clear()
+ cparams['bar'] = 'bat'
+
+ @event.listens_for(e, "do_connect")
+ def evt2(dialect, cargs, cparams):
+ return m1.our_connect(cargs, cparams)
+
+ with e.connect() as conn:
+ # called with args
+ eq_(
+ m1.mock_calls,
+ [call.our_connect(['foo', 'hoho'], {'bar': 'bat'})])
+
+ # returned our mock connection
+ is_(conn.connection.connection, m1.our_connect())
diff --git a/test/engine/test_parseconnect.py b/test/engine/test_parseconnect.py
index e53a99e15..98994e2de 100644
--- a/test/engine/test_parseconnect.py
+++ b/test/engine/test_parseconnect.py
@@ -5,7 +5,7 @@ from sqlalchemy.engine.default import DefaultDialect
import sqlalchemy as tsa
from sqlalchemy.testing import fixtures
from sqlalchemy import testing
-from sqlalchemy.testing.mock import Mock, MagicMock
+from sqlalchemy.testing.mock import Mock, MagicMock, call
from sqlalchemy import event
from sqlalchemy import select
@@ -325,6 +325,33 @@ class TestRegNewDBAPI(fixtures.TestBase):
e = create_engine("mysql+my_mock_dialect://")
assert isinstance(e.dialect, MockDialect)
+ @testing.requires.sqlite
+ def test_wrapper_hooks(self):
+ def get_dialect_cls(url):
+ url.drivername = "sqlite"
+ return url.get_dialect()
+
+ global WrapperFactory
+ WrapperFactory = Mock()
+ WrapperFactory.get_dialect_cls.side_effect = get_dialect_cls
+
+ from sqlalchemy.dialects import registry
+ registry.register("wrapperdialect", __name__, "WrapperFactory")
+
+ from sqlalchemy.dialects import sqlite
+ e = create_engine("wrapperdialect://")
+
+ eq_(e.dialect.name, "sqlite")
+ assert isinstance(e.dialect, sqlite.dialect)
+
+ eq_(
+ WrapperFactory.mock_calls,
+ [
+ call.get_dialect_cls(url.make_url("sqlite://")),
+ call.engine_created(e)
+ ]
+ )
+
class MockDialect(DefaultDialect):
@classmethod