diff options
Diffstat (limited to 'kombu/utils/json.py')
-rw-r--r-- | kombu/utils/json.py | 159 |
1 files changed, 98 insertions, 61 deletions
diff --git a/kombu/utils/json.py b/kombu/utils/json.py index cedaa793..ec6269e2 100644 --- a/kombu/utils/json.py +++ b/kombu/utils/json.py @@ -1,75 +1,75 @@ """JSON Serialization Utilities.""" -import datetime -import decimal -import json as stdjson +from __future__ import annotations + +import base64 +import json import uuid +from datetime import date, datetime, time +from decimal import Decimal +from typing import Any, Callable, TypeVar -try: - from django.utils.functional import Promise as DjangoPromise -except ImportError: # pragma: no cover - class DjangoPromise: - """Dummy object.""" +textual_types = () try: - import json - _json_extra_kwargs = {} - - class _DecodeError(Exception): - pass -except ImportError: # pragma: no cover - import simplejson as json - from simplejson.decoder import JSONDecodeError as _DecodeError - _json_extra_kwargs = { - 'use_decimal': False, - 'namedtuple_as_object': False, - } - + from django.utils.functional import Promise -_encoder_cls = type(json._default_encoder) -_default_encoder = None # ... set to JSONEncoder below. + textual_types += (Promise,) +except ImportError: + pass -class JSONEncoder(_encoder_cls): +class JSONEncoder(json.JSONEncoder): """Kombu custom json encoder.""" - def default(self, o, - dates=(datetime.datetime, datetime.date), - times=(datetime.time,), - textual=(decimal.Decimal, uuid.UUID, DjangoPromise), - isinstance=isinstance, - datetime=datetime.datetime, - text_t=str): - reducer = getattr(o, '__json__', None) + def default(self, o): + reducer = getattr(o, "__json__", None) if reducer is not None: return reducer() - else: - if isinstance(o, dates): - if not isinstance(o, datetime): - o = datetime(o.year, o.month, o.day, 0, 0, 0, 0) - r = o.isoformat() - if r.endswith("+00:00"): - r = r[:-6] + "Z" - return r - elif isinstance(o, times): - return o.isoformat() - elif isinstance(o, textual): - return text_t(o) - return super().default(o) + if isinstance(o, textual_types): + return str(o) + + for t, (marker, encoder) in _encoders.items(): + if isinstance(o, t): + return _as(marker, encoder(o)) + + # Bytes is slightly trickier, so we cannot put them directly + # into _encoders, because we use two formats: bytes, and base64. + if isinstance(o, bytes): + try: + return _as("bytes", o.decode("utf-8")) + except UnicodeDecodeError: + return _as("base64", base64.b64encode(o).decode("utf-8")) -_default_encoder = JSONEncoder + return super().default(o) -def dumps(s, _dumps=json.dumps, cls=None, default_kwargs=None, **kwargs): +def _as(t: str, v: Any): + return {"__type__": t, "__value__": v} + + +def dumps( + s, _dumps=json.dumps, cls=JSONEncoder, default_kwargs=None, **kwargs +): """Serialize object to json string.""" - if not default_kwargs: - default_kwargs = _json_extra_kwargs - return _dumps(s, cls=cls or _default_encoder, - **dict(default_kwargs, **kwargs)) + default_kwargs = default_kwargs or {} + return _dumps(s, cls=cls, **dict(default_kwargs, **kwargs)) + + +def object_hook(o: dict): + """Hook function to perform custom deserialization.""" + if o.keys() == {"__type__", "__value__"}: + decoder = _decoders.get(o["__type__"]) + if decoder: + return decoder(o["__value__"]) + else: + raise ValueError("Unsupported type", type, o) + else: + return o -def loads(s, _loads=json.loads, decode_bytes=True): +def loads(s, _loads=json.loads, decode_bytes=True, object_hook=object_hook): """Deserialize json from string.""" # None of the json implementations supports decoding from # a buffer/memoryview, or even reading from a stream @@ -78,14 +78,51 @@ def loads(s, _loads=json.loads, decode_bytes=True): # over. Note that pickle does support buffer/memoryview # </rant> if isinstance(s, memoryview): - s = s.tobytes().decode('utf-8') + s = s.tobytes().decode("utf-8") elif isinstance(s, bytearray): - s = s.decode('utf-8') + s = s.decode("utf-8") elif decode_bytes and isinstance(s, bytes): - s = s.decode('utf-8') - - try: - return _loads(s) - except _DecodeError: - # catch "Unpaired high surrogate" error - return stdjson.loads(s) + s = s.decode("utf-8") + + return _loads(s, object_hook=object_hook) + + +DecoderT = EncoderT = Callable[[Any], Any] +T = TypeVar("T") +EncodedT = TypeVar("EncodedT") + + +def register_type( + t: type[T], + marker: str, + encoder: Callable[[T], EncodedT], + decoder: Callable[[EncodedT], T], +): + """Add support for serializing/deserializing native python type.""" + _encoders[t] = (marker, encoder) + _decoders[marker] = decoder + + +_encoders: dict[type, tuple[str, EncoderT]] = {} +_decoders: dict[str, DecoderT] = { + "bytes": lambda o: o.encode("utf-8"), + "base64": lambda o: base64.b64decode(o.encode("utf-8")), +} + +# NOTE: datetime should be registered before date, +# because datetime is also instance of date. +register_type(datetime, "datetime", datetime.isoformat, datetime.fromisoformat) +register_type( + date, + "date", + lambda o: o.isoformat(), + lambda o: datetime.fromisoformat(o).date(), +) +register_type(time, "time", lambda o: o.isoformat(), time.fromisoformat) +register_type(Decimal, "decimal", str, Decimal) +register_type( + uuid.UUID, + "uuid", + lambda o: {"hex": o.hex, "version": o.version}, + lambda o: uuid.UUID(**o), +) |