Skip to content

Commit 971de5e

Browse files
committed
Polymorphic (de)serialization support (#7104)
1 parent a7afbb9 commit 971de5e

File tree

2 files changed

+180
-1
lines changed

2 files changed

+180
-1
lines changed

src/azul/attrs.py

Lines changed: 119 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
require,
4242
)
4343
from azul.json import (
44+
PolymorphicSerializable,
4445
Serializable,
4546
)
4647
from azul.types import (
@@ -483,6 +484,10 @@ def _metadata[V](self, key: str, default: V) -> V:
483484
except KeyError:
484485
return default
485486

487+
@cached_property
488+
def discriminator(self) -> str | None:
489+
return self._metadata('discriminator', None)
490+
486491
def handle(self, x: str) -> T:
487492
if self.custom is None:
488493
return self._handle(x, self._reify(self.field.type))
@@ -536,7 +541,12 @@ def _handle(self, x: str, field_type: Any):
536541
elif issubclass(field_type, Serializable):
537542
inner_cls_name = field_type.__name__
538543
self.globals[inner_cls_name] = field_type
539-
return self._serializable(x, inner_cls_name)
544+
is_polymorphic = issubclass(field_type, PolymorphicSerializable)
545+
has_discriminator = self.discriminator is not None
546+
if is_polymorphic and has_discriminator:
547+
return self._polymorphic(x, inner_cls_name)
548+
else:
549+
return self._serializable(x, inner_cls_name)
540550
else:
541551
origin = get_origin(field_type)
542552
if origin in (Union, UnionType):
@@ -575,6 +585,10 @@ def _optional(self, x: str, field_type: type) -> T:
575585
def _serializable(self, x: str, inner_cls_name: str) -> T:
576586
raise NotImplementedError
577587

588+
@abstractmethod
def _polymorphic(self, x: str, inner_cls_name: str) -> T:
    """
    Handle a field value whose static type is a PolymorphicSerializable
    subclass and for which a discriminator property was configured.

    :param x: the source-level expression or variable denoting the value

    :param inner_cls_name: the name of the statically declared class of the
                           value
    """
    raise NotImplementedError
591+
578592
@abstractmethod
579593
def _list(self, x: str, item_type: type) -> T:
580594
raise NotImplementedError
@@ -603,6 +617,15 @@ def _serializable(self, x: str, inner_cls_name: str) -> Source:
603617
f'{x} = {inner_cls_name}.from_json({x})'
604618
]
605619

620+
def _polymorphic(self, x: str, inner_cls_name: str) -> Source:
    """
    Generate the statements that deserialize a polymorphic value held in
    the variable named by ``x``. The generated code reads the discriminator
    property from the value, resolves it to a concrete class by calling
    ``cls_from_json`` on the statically declared class, and then rebinds
    ``x`` to the result of calling ``from_json`` on that concrete class.

    :param x: the name of the variable holding the serialized value

    :param inner_cls_name: the name under which the statically declared
                           class is available to the generated code
    """
    depth = next(self.depth)
    cls = f'cls{depth}'
    # Render the property name with repr() so that the generated subscript
    # is a valid string literal even if the name contains quotes or
    # backslashes, instead of pasting the raw name between double quotes.
    key = repr(self.discriminator)
    return [
        f'{cls} = {x}[{key}]',
        f'{cls} = {inner_cls_name}.cls_from_json({cls})',
        f'{x} = {cls}.from_json({x})'
    ]
628+
606629
def _primitive(self, x: str, field_type: type) -> Source:
607630
return [
608631
f'if not isinstance({x}, {field_type.__name__}):', [
@@ -670,6 +693,9 @@ def _optional(self, x: str, field_type: type) -> str:
670693
def _serializable(self, x: str, inner_cls_name: str) -> str:
671694
return f'{x}.to_json()'
672695

696+
def _polymorphic(self, x: str, inner_cls_name: str) -> str:
    """
    Generate an expression that serializes the polymorphic value denoted by
    ``x``. The expression evaluates to a new dictionary holding the entries
    of ``x.to_json()`` plus the discriminator property, whose value is the
    result of ``x.cls_to_json()``.

    :param x: the source-level expression denoting the value to serialize

    :param inner_cls_name: the name of the statically declared class of the
                           value, not needed for serialization
    """
    # The property name is emitted as a string literal key of a dict display
    # rather than as a keyword argument to dict(). A keyword argument would
    # restrict the name to valid Python identifiers that aren't keywords,
    # and any other name would make the generated code a syntax error.
    key = repr(self.discriminator)
    return f'{{**{x}.to_json(), {key}: {x}.cls_to_json()}}'
698+
673699
def _list(self, x: str, item_type: type) -> str:
674700
depth = next(self.depth)
675701
v = f'v{depth}'
@@ -740,3 +766,95 @@ def _set_field_metadata[T: attrs.Attribute](field: T | None, key, value):
740766
metadata = field.metadata.setdefault('azul', {})
741767
metadata[key] = value
742768
return field
769+
770+
771+
def polymorphic[T: attrs.Attribute](field: T | None = None,
772+
*,
773+
discriminator: str
774+
) -> T:
775+
"""
776+
Mark an attrs field to use the given name for the discriminator property in
777+
serialized instances of PolymorphicSerializable that occur in the value of
778+
that field. The given discriminator property of a serialized instance
779+
represents the type to use when deserializing that instance again.
780+
781+
>>> from azul.json import RegisteredPolymorphicSerializable
782+
783+
>>> class Inner(SerializableAttrs, RegisteredPolymorphicSerializable):
784+
... pass
785+
786+
>>> @attrs.frozen
787+
... class InnerWithInt(Inner):
788+
... x: int
789+
790+
>>> @attrs.frozen
791+
... class InnerWithStr(Inner):
792+
... y: str
793+
794+
>>> @attrs.frozen(kw_only=True)
795+
... class Outer(SerializableAttrs):
796+
... inner: Inner = polymorphic(discriminator='type')
797+
... inners: list[Inner] = polymorphic(discriminator='_cls')
798+
799+
>>> from azul.doctests import assert_json
800+
801+
>>> outer = Outer(inner=InnerWithInt(42),
802+
... inners=[InnerWithStr('foo'), InnerWithInt(7)])
803+
>>> assert_json(outer.to_json())
804+
{
805+
"inner": {
806+
"x": 42,
807+
"type": "InnerWithInt"
808+
},
809+
"inners": [
810+
{
811+
"y": "foo",
812+
"_cls": "InnerWithStr"
813+
},
814+
{
815+
"x": 7,
816+
"_cls": "InnerWithInt"
817+
}
818+
]
819+
}
820+
>>> Outer.from_json(outer.to_json()) == outer
821+
True
822+
823+
In order to enable polymorphic serialization of the value of a given field,
824+
the discriminator property needs to be specified explicitly, otherwise the
825+
serialization framework will resort to the static type of the field.
826+
827+
>>> @attrs.frozen
828+
... class GenericOuter[T: Inner](SerializableAttrs):
829+
... inner: T
830+
831+
>>> class StaticOuter(GenericOuter[InnerWithInt]):
832+
... pass
833+
834+
>>> outer = StaticOuter(InnerWithInt(42))
835+
>>> outer.to_json()
836+
{'inner': {'x': 42}}
837+
838+
Despite the fact that ``{'x': 42}`` does not encode any type information,
839+
``from_json`` can tell from the static type of the field that {'x': 42}
840+
should be deserialized as an ``InnerWithInt``.
841+
842+
>>> StaticOuter.from_json(outer.to_json()).inner
843+
InnerWithInt(x=42)
844+
845+
>>> StaticOuter.from_json(outer.to_json()) == outer
846+
True
847+
848+
However, when the static type of the field is not concrete, deserialization
849+
may fail or, like in this case, lose information by creating an instance of
850+
the parent class instead of the class that was serialized.
851+
852+
>>> @attrs.frozen
853+
... class AbstractOuter(SerializableAttrs):
854+
... inner: Inner
855+
856+
>>> outer = AbstractOuter(InnerWithInt(42))
857+
>>> AbstractOuter.from_json(outer.to_json()).inner # doctest: +ELLIPSIS
858+
<azul.attrs.Inner object at ...>
859+
"""
860+
return _set_field_metadata(field, 'discriminator', discriminator)

src/azul/json.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,67 @@ def to_json(self) -> AnyJSON:
267267
raise NotImplementedError
268268

269269

270+
class PolymorphicSerializable(Serializable):
    """
    Base class for hierarchies whose instances can be converted to JSON and
    back without losing track of the concrete subclass each instance belongs
    to.
    """

    @classmethod
    def cls_from_json(cls, json: AnyJSON) -> type[Self]:
        """
        Return the subtype of this class that the given JSON represents. To
        be implemented by subclasses.
        """
        raise NotImplementedError

    @classmethod
    def cls_to_json(cls) -> AnyJSON:
        """
        Return the JSON representation of this class, as opposed to that of
        one of its instances. To be implemented by subclasses.
        """
        raise NotImplementedError
289+
290+
291+
class RegisteredPolymorphicSerializable(PolymorphicSerializable):
    """
    A polymorphically serializable class that tracks its subclasses in a
    registry and uses their name to discriminate serialized instances. It
    requires every subclass to be registered before instances of that subclass
    can be (de)serialized. It also requires the name of each subclass to be
    unique, regardless of the module the subclass is defined in.
    """

    # Maps the name of each subclass to the subclass itself. The dictionary is
    # created once, in the body of this class, so unless a subclass rebinds the
    # attribute, all subclasses share the same registry.
    _registry: dict[str, type[Self]] = {}

    @classmethod
    def cls_to_json(cls) -> AnyJSON:
        """
        Return the name of this class, after asserting that the registry maps
        that name to this very class. Raises KeyError if no class is
        registered under that name.
        """
        assert cls._registry[cls.__name__] == cls
        return cls.__name__

    @classmethod
    def cls_from_json(cls, json: AnyJSON) -> type[Self]:
        """
        Return the class registered under the name given by the JSON string.
        Raises KeyError if no class is registered under that name.
        """
        return cls._registry[json_str(json)]

    def __init_subclass__(cls, **kwargs):
        """
        Register every newly defined subclass under its name. Class keyword
        arguments are passed on to the next ``__init_subclass__`` in the MRO
        so that this class cooperates with other bases that consume them.
        """
        super().__init_subclass__(**kwargs)
        try:
            other_cls = cls._registry[cls.__name__]
        except KeyError:
            pass
        else:
            # For attrs classes, this hook is invoked twice: once for the
            # original class and once for the attrs-generated replacement.
            # These are two different objects, so they are neither the same
            # nor equal, and it is difficult to tell whether we're dealing
            # with the attrs replacement or a genuine collision. Both original
            # and replacement reference the same containing module, so we
            # assume that two classes of the same name from the same module
            # indicate that attrs is involved and do not constitute a
            # collision.
            assert other_cls.__module__ == cls.__module__, R(
                'Class name collision', cls, other_cls)
        # The most recently defined class wins, so for attrs classes the
        # registry ends up referring to the replacement.
        cls._registry[cls.__name__] = cls
329+
330+
270331
class Parseable(Serializable):
271332
"""
272333
A class whose instances have a string representation that can be used in

0 commit comments

Comments
 (0)