summaryrefslogtreecommitdiff
path: root/kombu/utils/json.py
diff options
context:
space:
mode:
Diffstat (limited to 'kombu/utils/json.py')
-rw-r--r--kombu/utils/json.py159
1 files changed, 98 insertions, 61 deletions
diff --git a/kombu/utils/json.py b/kombu/utils/json.py
index cedaa793..ec6269e2 100644
--- a/kombu/utils/json.py
+++ b/kombu/utils/json.py
@@ -1,75 +1,75 @@
"""JSON Serialization Utilities."""
-import datetime
-import decimal
-import json as stdjson
+from __future__ import annotations
+
+import base64
+import json
import uuid
+from datetime import date, datetime, time
+from decimal import Decimal
+from typing import Any, Callable, TypeVar
-try:
- from django.utils.functional import Promise as DjangoPromise
-except ImportError: # pragma: no cover
- class DjangoPromise:
- """Dummy object."""
+textual_types = ()
try:
- import json
- _json_extra_kwargs = {}
-
- class _DecodeError(Exception):
- pass
-except ImportError: # pragma: no cover
- import simplejson as json
- from simplejson.decoder import JSONDecodeError as _DecodeError
- _json_extra_kwargs = {
- 'use_decimal': False,
- 'namedtuple_as_object': False,
- }
-
+ from django.utils.functional import Promise
-_encoder_cls = type(json._default_encoder)
-_default_encoder = None # ... set to JSONEncoder below.
+ textual_types += (Promise,)
+except ImportError:
+ pass
-class JSONEncoder(_encoder_cls):
+class JSONEncoder(json.JSONEncoder):
"""Kombu custom json encoder."""
- def default(self, o,
- dates=(datetime.datetime, datetime.date),
- times=(datetime.time,),
- textual=(decimal.Decimal, uuid.UUID, DjangoPromise),
- isinstance=isinstance,
- datetime=datetime.datetime,
- text_t=str):
- reducer = getattr(o, '__json__', None)
+ def default(self, o):
+ reducer = getattr(o, "__json__", None)
if reducer is not None:
return reducer()
- else:
- if isinstance(o, dates):
- if not isinstance(o, datetime):
- o = datetime(o.year, o.month, o.day, 0, 0, 0, 0)
- r = o.isoformat()
- if r.endswith("+00:00"):
- r = r[:-6] + "Z"
- return r
- elif isinstance(o, times):
- return o.isoformat()
- elif isinstance(o, textual):
- return text_t(o)
- return super().default(o)
+ if isinstance(o, textual_types):
+ return str(o)
+
+ for t, (marker, encoder) in _encoders.items():
+ if isinstance(o, t):
+ return _as(marker, encoder(o))
+
+ # Bytes is slightly trickier, so we cannot put them directly
+ # into _encoders, because we use two formats: bytes, and base64.
+ if isinstance(o, bytes):
+ try:
+ return _as("bytes", o.decode("utf-8"))
+ except UnicodeDecodeError:
+ return _as("base64", base64.b64encode(o).decode("utf-8"))
-_default_encoder = JSONEncoder
+ return super().default(o)
-def dumps(s, _dumps=json.dumps, cls=None, default_kwargs=None, **kwargs):
+def _as(t: str, v: Any):
+ return {"__type__": t, "__value__": v}
+
+
+def dumps(
+ s, _dumps=json.dumps, cls=JSONEncoder, default_kwargs=None, **kwargs
+):
"""Serialize object to json string."""
- if not default_kwargs:
- default_kwargs = _json_extra_kwargs
- return _dumps(s, cls=cls or _default_encoder,
- **dict(default_kwargs, **kwargs))
+ default_kwargs = default_kwargs or {}
+ return _dumps(s, cls=cls, **dict(default_kwargs, **kwargs))
+
+
+def object_hook(o: dict):
+ """Hook function to perform custom deserialization."""
+ if o.keys() == {"__type__", "__value__"}:
+ decoder = _decoders.get(o["__type__"])
+ if decoder:
+ return decoder(o["__value__"])
+ else:
+ raise ValueError("Unsupported type", type, o)
+ else:
+ return o
-def loads(s, _loads=json.loads, decode_bytes=True):
+def loads(s, _loads=json.loads, decode_bytes=True, object_hook=object_hook):
"""Deserialize json from string."""
# None of the json implementations supports decoding from
# a buffer/memoryview, or even reading from a stream
@@ -78,14 +78,51 @@ def loads(s, _loads=json.loads, decode_bytes=True):
# over. Note that pickle does support buffer/memoryview
# </rant>
if isinstance(s, memoryview):
- s = s.tobytes().decode('utf-8')
+ s = s.tobytes().decode("utf-8")
elif isinstance(s, bytearray):
- s = s.decode('utf-8')
+ s = s.decode("utf-8")
elif decode_bytes and isinstance(s, bytes):
- s = s.decode('utf-8')
-
- try:
- return _loads(s)
- except _DecodeError:
- # catch "Unpaired high surrogate" error
- return stdjson.loads(s)
+ s = s.decode("utf-8")
+
+ return _loads(s, object_hook=object_hook)
+
+
+DecoderT = EncoderT = Callable[[Any], Any]
+T = TypeVar("T")
+EncodedT = TypeVar("EncodedT")
+
+
+def register_type(
+ t: type[T],
+ marker: str,
+ encoder: Callable[[T], EncodedT],
+ decoder: Callable[[EncodedT], T],
+):
+ """Add support for serializing/deserializing native python type."""
+ _encoders[t] = (marker, encoder)
+ _decoders[marker] = decoder
+
+
+_encoders: dict[type, tuple[str, EncoderT]] = {}
+_decoders: dict[str, DecoderT] = {
+ "bytes": lambda o: o.encode("utf-8"),
+ "base64": lambda o: base64.b64decode(o.encode("utf-8")),
+}
+
+# NOTE: datetime should be registered before date,
+# because datetime is also instance of date.
+register_type(datetime, "datetime", datetime.isoformat, datetime.fromisoformat)
+register_type(
+ date,
+ "date",
+ lambda o: o.isoformat(),
+ lambda o: datetime.fromisoformat(o).date(),
+)
+register_type(time, "time", lambda o: o.isoformat(), time.fromisoformat)
+register_type(Decimal, "decimal", str, Decimal)
+register_type(
+ uuid.UUID,
+ "uuid",
+ lambda o: {"hex": o.hex, "version": o.version},
+ lambda o: uuid.UUID(**o),
+)