summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2013-12-18 18:26:15 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2013-12-18 18:26:15 -0500
commit2692238f45ae4d2f46949dfa52b16132bd266e0e (patch)
tree3ea336d98a01461b51da530bf0aca155b891fdd0
parentf701f87c1374c1e4d80b9f47c17632518cece765 (diff)
downloadsqlalchemy-2692238f45ae4d2f46949dfa52b16132bd266e0e.tar.gz
- Improvements to the system by which SQL types generate within
``__repr__()``, particularly with regards to the MySQL integer/numeric/ character types which feature a wide variety of keyword arguments. The ``__repr__()`` is important for use with Alembic autogenerate for when Python code is rendered in a migration script. [ticket:2893]
-rw-r--r--doc/build/changelog/changelog_09.rst10
-rw-r--r--lib/sqlalchemy/dialects/mysql/base.py35
-rw-r--r--lib/sqlalchemy/sql/sqltypes.py19
-rw-r--r--lib/sqlalchemy/util/langhelpers.py85
-rw-r--r--test/base/test_utils.py49
-rw-r--r--test/dialect/mysql/test_types.py21
-rw-r--r--test/sql/test_types.py8
7 files changed, 175 insertions, 52 deletions
diff --git a/doc/build/changelog/changelog_09.rst b/doc/build/changelog/changelog_09.rst
index 335024016..3106eefb8 100644
--- a/doc/build/changelog/changelog_09.rst
+++ b/doc/build/changelog/changelog_09.rst
@@ -15,6 +15,16 @@
:version: 0.9.0b2
.. change::
+ :tags: bug, mysql
+ :tickets: 2893
+
+ Improvements to the system by which SQL types generate within
+ ``__repr__()``, particularly with regards to the MySQL integer/numeric/
+ character types which feature a wide variety of keyword arguments.
+ The ``__repr__()`` is important for use with Alembic autogenerate
+ for when Python code is rendered in a migration script.
+
+ .. change::
:tags: feature, postgresql
:tickets: 2581
:pullreq: github:50
diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py
index 971005a84..cc906b111 100644
--- a/lib/sqlalchemy/dialects/mysql/base.py
+++ b/lib/sqlalchemy/dialects/mysql/base.py
@@ -380,13 +380,21 @@ SET_RE = re.compile(
class _NumericType(object):
- """Base for MySQL numeric types."""
+ """Base for MySQL numeric types.
+
+ This is the base both for NUMERIC as well as INTEGER, hence
+ it's a mixin.
+
+ """
def __init__(self, unsigned=False, zerofill=False, **kw):
self.unsigned = unsigned
self.zerofill = zerofill
super(_NumericType, self).__init__(**kw)
+ def __repr__(self):
+ return util.generic_repr(self,
+ to_inspect=[_NumericType, sqltypes.Numeric])
class _FloatType(_NumericType, sqltypes.Float):
def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
@@ -401,18 +409,24 @@ class _FloatType(_NumericType, sqltypes.Float):
super(_FloatType, self).__init__(precision=precision, asdecimal=asdecimal, **kw)
self.scale = scale
+ def __repr__(self):
+ return util.generic_repr(self,
+ to_inspect=[_FloatType, _NumericType, sqltypes.Float])
class _IntegerType(_NumericType, sqltypes.Integer):
def __init__(self, display_width=None, **kw):
self.display_width = display_width
super(_IntegerType, self).__init__(**kw)
+ def __repr__(self):
+ return util.generic_repr(self,
+ to_inspect=[_IntegerType, _NumericType, sqltypes.Integer])
class _StringType(sqltypes.String):
"""Base for MySQL string types."""
def __init__(self, charset=None, collation=None,
- ascii=False, binary=False,
+ ascii=False, binary=False, unicode=False,
national=False, **kw):
self.charset = charset
@@ -420,16 +434,14 @@ class _StringType(sqltypes.String):
kw.setdefault('collation', kw.pop('collate', collation))
self.ascii = ascii
- # We have to munge the 'unicode' param strictly as a dict
- # otherwise 2to3 will turn it into str.
- self.__dict__['unicode'] = kw.get('unicode', False)
- # sqltypes.String does not accept the 'unicode' arg at all.
- if 'unicode' in kw:
- del kw['unicode']
+ self.unicode = unicode
self.binary = binary
self.national = national
super(_StringType, self).__init__(**kw)
+ def __repr__(self):
+ return util.generic_repr(self,
+ to_inspect=[_StringType, sqltypes.String])
class NUMERIC(_NumericType, sqltypes.NUMERIC):
"""MySQL NUMERIC type."""
@@ -1141,6 +1153,10 @@ class ENUM(sqltypes.Enum, _EnumeratedValues):
_StringType.__init__(self, length=length, **kw)
sqltypes.Enum.__init__(self, *values)
+ def __repr__(self):
+ return util.generic_repr(self,
+ to_inspect=[ENUM, _StringType, sqltypes.Enum])
+
def bind_processor(self, dialect):
super_convert = super(ENUM, self).bind_processor(dialect)
@@ -1287,6 +1303,9 @@ MSFloat = FLOAT
MSInteger = INTEGER
colspecs = {
+ _IntegerType: _IntegerType,
+ _NumericType: _NumericType,
+ _FloatType: _FloatType,
sqltypes.Numeric: NUMERIC,
sqltypes.Float: FLOAT,
sqltypes.Time: TIME,
diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py
index 6ed20084b..259749cc4 100644
--- a/lib/sqlalchemy/sql/sqltypes.py
+++ b/lib/sqlalchemy/sql/sqltypes.py
@@ -906,15 +906,15 @@ class SchemaType(SchemaEventTarget):
"""
- def __init__(self, **kw):
- name = kw.pop('name', None)
+ def __init__(self, name=None, schema=None, metadata=None,
+ inherit_schema=False, quote=None):
if name is not None:
- self.name = quoted_name(name, kw.pop('quote', None))
+ self.name = quoted_name(name, quote)
else:
self.name = None
- self.schema = kw.pop('schema', None)
- self.metadata = kw.pop('metadata', None)
- self.inherit_schema = kw.pop('inherit_schema', False)
+ self.schema = schema
+ self.metadata = metadata
+ self.inherit_schema = inherit_schema
if self.metadata:
event.listen(
self.metadata,
@@ -1110,10 +1110,9 @@ class Enum(String, SchemaType):
SchemaType.__init__(self, **kw)
def __repr__(self):
- return util.generic_repr(self, [
- ("native_enum", True),
- ("name", None)
- ])
+ return util.generic_repr(self,
+ to_inspect=[Enum, SchemaType],
+ )
def _should_create_constraint(self, compiler):
return not self.native_enum or \
diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py
index 8cf3db1bb..1a66426d0 100644
--- a/lib/sqlalchemy/util/langhelpers.py
+++ b/lib/sqlalchemy/util/langhelpers.py
@@ -19,6 +19,7 @@ from functools import update_wrapper
from .. import exc
import hashlib
from . import compat
+from . import _collections
def md5_hex(x):
if compat.py3k:
@@ -392,44 +393,66 @@ def generic_repr(obj, additional_kw=(), to_inspect=None):
"""
if to_inspect is None:
- to_inspect = obj
+ to_inspect = [obj]
+ else:
+ to_inspect = _collections.to_list(to_inspect)
missing = object()
- def genargs():
+ pos_args = []
+ kw_args = _collections.OrderedDict()
+ vargs = None
+ for i, insp in enumerate(to_inspect):
try:
- (args, vargs, vkw, defaults) = \
- inspect.getargspec(to_inspect.__init__)
+ (_args, _vargs, vkw, defaults) = \
+ inspect.getargspec(insp.__init__)
except TypeError:
- return
+ continue
+ else:
+ default_len = defaults and len(defaults) or 0
+ if i == 0:
+ if _vargs:
+ vargs = _vargs
+ if default_len:
+ pos_args.extend(_args[1:-default_len])
+ else:
+ pos_args.extend(_args[1:])
+ else:
+ kw_args.update([
+ (arg, missing) for arg in _args[1:-default_len]
+ ])
- default_len = defaults and len(defaults) or 0
+ if default_len:
+ kw_args.update([
+ (arg, default)
+ for arg, default
+ in zip(_args[-default_len:], defaults)
+ ])
+ output = []
- if not default_len:
- for arg in args[1:]:
- yield repr(getattr(obj, arg, None))
- if vargs is not None and hasattr(obj, vargs):
- yield ', '.join(repr(val) for val in getattr(obj, vargs))
- else:
- for arg in args[1:-default_len]:
- yield repr(getattr(obj, arg, None))
- for (arg, defval) in zip(args[-default_len:], defaults):
- try:
- val = getattr(obj, arg, missing)
- if val is not missing and val != defval:
- yield '%s=%r' % (arg, val)
- except:
- pass
- if additional_kw:
- for arg, defval in additional_kw:
- try:
- val = getattr(obj, arg, missing)
- if val is not missing and val != defval:
- yield '%s=%r' % (arg, val)
- except:
- pass
-
- return "%s(%s)" % (obj.__class__.__name__, ", ".join(genargs()))
+ output.extend(repr(getattr(obj, arg, None)) for arg in pos_args)
+
+ if vargs is not None and hasattr(obj, vargs):
+ output.extend([repr(val) for val in getattr(obj, vargs)])
+
+ for arg, defval in kw_args.items():
+ try:
+ val = getattr(obj, arg, missing)
+ if val is not missing and val != defval:
+ output.append('%s=%r' % (arg, val))
+ except:
+ pass
+
+ if additional_kw:
+ for arg, defval in additional_kw:
+ try:
+ val = getattr(obj, arg, missing)
+ if val is not missing and val != defval:
+ output.append('%s=%r' % (arg, val))
+ except:
+ pass
+
+ return "%s(%s)" % (obj.__class__.__name__, ", ".join(output))
class portable_instancemethod(object):
diff --git a/test/base/test_utils.py b/test/base/test_utils.py
index 2fd1edbb5..1946bd704 100644
--- a/test/base/test_utils.py
+++ b/test/base/test_utils.py
@@ -1419,6 +1419,55 @@ class GenericReprTest(fixtures.TestBase):
"Foo(b=5, d=7)"
)
+ def test_multi_kw(self):
+ class Foo(object):
+ def __init__(self, a, b, c=3, d=4):
+ self.a = a
+ self.b = b
+ self.c = c
+ self.d = d
+ class Bar(Foo):
+ def __init__(self, e, f, g=5, **kw):
+ self.e = e
+ self.f = f
+ self.g = g
+ super(Bar, self).__init__(**kw)
+
+ eq_(
+ util.generic_repr(
+ Bar('e', 'f', g=7, a=6, b=5, d=9),
+ to_inspect=[Bar, Foo]
+ ),
+ "Bar('e', 'f', g=7, a=6, b=5, d=9)"
+ )
+
+ eq_(
+ util.generic_repr(
+ Bar('e', 'f', a=6, b=5),
+ to_inspect=[Bar, Foo]
+ ),
+ "Bar('e', 'f', a=6, b=5)"
+ )
+
+ def test_multi_kw_repeated(self):
+ class Foo(object):
+ def __init__(self, a=1, b=2):
+ self.a = a
+ self.b = b
+ class Bar(Foo):
+ def __init__(self, b=3, c=4, **kw):
+ self.c = c
+ super(Bar, self).__init__(b=b, **kw)
+
+ eq_(
+ util.generic_repr(
+ Bar(a='a', b='b', c='c'),
+ to_inspect=[Bar, Foo]
+ ),
+ "Bar(b='b', c='c', a='a')"
+ )
+
+
def test_discard_vargs(self):
class Foo(object):
def __init__(self, a, b, *args):
diff --git a/test/dialect/mysql/test_types.py b/test/dialect/mysql/test_types.py
index 071b8440f..cd6cba18e 100644
--- a/test/dialect/mysql/test_types.py
+++ b/test/dialect/mysql/test_types.py
@@ -142,8 +142,15 @@ class TypesTest(fixtures.TestBase, AssertsExecutionResults, AssertsCompiledSQL):
]
for type_, args, kw, res in columns:
+ type_inst = type_(*args, **kw)
self.assert_compile(
- type_(*args, **kw),
+ type_inst,
+ res
+ )
+ # test that repr() copies out all arguments
+ print "mysql.%r" % type_inst
+ self.assert_compile(
+ eval("mysql.%r" % type_inst),
res
)
@@ -233,14 +240,22 @@ class TypesTest(fixtures.TestBase, AssertsExecutionResults, AssertsCompiledSQL):
(mysql.ENUM, ["foo", "bar"], {'unicode':True},
'''ENUM('foo','bar') UNICODE'''),
- (String, [20], {"collation":"utf8"}, 'VARCHAR(20) COLLATE utf8')
+ (String, [20], {"collation": "utf8"}, 'VARCHAR(20) COLLATE utf8')
]
for type_, args, kw, res in columns:
+ type_inst = type_(*args, **kw)
+ self.assert_compile(
+ type_inst,
+ res
+ )
+ # test that repr() copies out all arguments
self.assert_compile(
- type_(*args, **kw),
+ eval("mysql.%r" % type_inst)
+ if type_ is not String
+ else eval("%r" % type_inst),
res
)
diff --git a/test/sql/test_types.py b/test/sql/test_types.py
index a04f56ba4..dbc4716ef 100644
--- a/test/sql/test_types.py
+++ b/test/sql/test_types.py
@@ -967,6 +967,14 @@ class EnumTest(AssertsCompiledSQL, fixtures.TestBase):
# depending on backend.
assert "('x'," in e.print_sql()
+ def test_repr(self):
+ e = Enum("x", "y", name="somename", convert_unicode=True,
+ quote=True, inherit_schema=True)
+ eq_(
+ repr(e),
+ "Enum('x', 'y', name='somename', inherit_schema=True)"
+ )
+
class BinaryTest(fixtures.TestBase, AssertsExecutionResults):
__excluded_on__ = (
('mysql', '<', (4, 1, 1)), # screwy varbinary types