summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2015-05-23 09:07:36 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2015-05-23 09:07:36 -0400
commite9921ad356fee4edb56007ae39793fb2211f13cf (patch)
tree476b9d01dd3e3f04c247048ff86f8d2417fa9c84
parent04c625467e65ec4189d4fd73e0e10c727f04dce6 (diff)
downloadsqlalchemy-e9921ad356fee4edb56007ae39793fb2211f13cf.tar.gz
- fix some tests related to the URL change and try to make
the URL design a little simpler
-rw-r--r--lib/sqlalchemy/engine/strategies.py3
-rw-r--r--lib/sqlalchemy/engine/url.py16
-rw-r--r--test/engine/test_reconnect.py4
3 files changed, 17 insertions, 6 deletions
diff --git a/lib/sqlalchemy/engine/strategies.py b/lib/sqlalchemy/engine/strategies.py
index e2a086de4..a539ee9f7 100644
--- a/lib/sqlalchemy/engine/strategies.py
+++ b/lib/sqlalchemy/engine/strategies.py
@@ -48,7 +48,8 @@ class DefaultEngineStrategy(EngineStrategy):
# create url.URL object
u = url.make_url(name_or_url)
- entrypoint, dialect_cls = u._get_dialect_plus_entrypoint()
+ entrypoint = u._get_entrypoint()
+ dialect_cls = entrypoint.get_dialect_cls(u)
if kwargs.pop('_coerce_config', False):
def pop_kwarg(key, default=None):
diff --git a/lib/sqlalchemy/engine/url.py b/lib/sqlalchemy/engine/url.py
index 07f6a5730..32e3f8a6b 100644
--- a/lib/sqlalchemy/engine/url.py
+++ b/lib/sqlalchemy/engine/url.py
@@ -117,7 +117,13 @@ class URL(object):
else:
return self.drivername.split('+')[1]
- def _get_dialect_plus_entrypoint(self):
+ def _get_entrypoint(self):
+ """Return the "entry point" dialect class.
+
+ This is normally the dialect itself except in the case when the
+ returned class implements the get_dialect_cls() method.
+
+ """
if '+' not in self.drivername:
name = self.drivername
else:
@@ -129,16 +135,16 @@ class URL(object):
if hasattr(cls, 'dialect') and \
isinstance(cls.dialect, type) and \
issubclass(cls.dialect, Dialect):
- return cls.dialect, cls.dialect
+ return cls.dialect
else:
- dialect_cls = cls.get_dialect_cls(self)
- return cls, dialect_cls
+ return cls
def get_dialect(self):
"""Return the SQLAlchemy database dialect class corresponding
to this URL's driver name.
"""
- entrypoint, dialect_cls = self._get_dialect_plus_entrypoint()
+ entrypoint = self._get_entrypoint()
+ dialect_cls = entrypoint.get_dialect_cls(self)
return dialect_cls
def translate_connect_args(self, names=[], **kw):
diff --git a/test/engine/test_reconnect.py b/test/engine/test_reconnect.py
index 619319693..39ebcc91b 100644
--- a/test/engine/test_reconnect.py
+++ b/test/engine/test_reconnect.py
@@ -370,6 +370,9 @@ class MockReconnectTest(fixtures.TestBase):
mock_dialect = Mock()
class MyURL(URL):
+ def _get_entrypoint(self):
+ return Dialect
+
def get_dialect(self):
return Dialect
@@ -420,6 +423,7 @@ class CursorErrTest(fixtures.TestBase):
from sqlalchemy.engine import default
url = Mock(
get_dialect=lambda: default.DefaultDialect,
+ _get_entrypoint=lambda: default.DefaultDialect,
translate_connect_args=lambda: {}, query={},)
eng = testing_engine(
url, options=dict(module=dbapi, _initialize=initialize))