From d850e56dd0e2d392bb6ead388c0f0a59f68e1bd2 Mon Sep 17 00:00:00 2001 From: Alexey Popravka Date: Mon, 24 Mar 2014 15:31:06 +0200 Subject: Unpacker's ext_hook fixed + tests --- test/test_unpack.py | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) (limited to 'test') diff --git a/test/test_unpack.py b/test/test_unpack.py index 544cebf..275f124 100644 --- a/test/test_unpack.py +++ b/test/test_unpack.py @@ -1,6 +1,6 @@ from io import BytesIO import sys -from msgpack import Unpacker, packb, OutOfData +from msgpack import Unpacker, packb, OutOfData, ExtType from pytest import raises, mark @@ -42,6 +42,29 @@ def test_unpacker_hook_refcnt(): assert sys.getrefcount(hook) == basecnt +def test_unpacker_ext_hook(): + + class MyUnpacker(Unpacker): + + def __init__(self): + super().__init__(ext_hook=self._hook, encoding='utf-8') + + def _hook(self, code, data): + if code == 1: + return int(data) + else: + return ExtType(code, data) + + unpacker = MyUnpacker() + unpacker.feed(packb({'a': 1}, encoding='utf-8')) + assert unpacker.unpack() == {'a': 1} + unpacker.feed(packb({'a': ExtType(1, b'123')}, encoding='utf-8')) + assert unpacker.unpack() == {'a': 123} + unpacker.feed(packb({'a': ExtType(2, b'321')}, encoding='utf-8')) + assert unpacker.unpack() == {'a': ExtType(2, b'321')} + + if __name__ == '__main__': test_unpack_array_header_from_file() test_unpacker_hook_refcnt() + test_unpacker_ext_hook() -- cgit v1.2.1 From ee38505db59c55d2d96516274030eed712028039 Mon Sep 17 00:00:00 2001 From: Alexey Popravka Date: Mon, 24 Mar 2014 15:42:16 +0200 Subject: fixed super() for python2 --- test/test_unpack.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'test') diff --git a/test/test_unpack.py b/test/test_unpack.py index 275f124..8d0d949 100644 --- a/test/test_unpack.py +++ b/test/test_unpack.py @@ -47,7 +47,8 @@ def test_unpacker_ext_hook(): class MyUnpacker(Unpacker): def __init__(self): - super().__init__(ext_hook=self._hook, encoding='utf-8') + super(MyUnpacker, self).__init__(ext_hook=self._hook, + encoding='utf-8') def _hook(self, code, data): if code == 1: -- cgit v1.2.1