Skip to content

Commit 9fc7fc2

Browse files
committed
Polymorphic (de)serialization support (#7104)
1 parent 2fa130b commit 9fc7fc2

File tree

2 files changed

+178
-1
lines changed

2 files changed

+178
-1
lines changed

src/azul/attrs.py

Lines changed: 118 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@
4141
require,
4242
)
4343
from azul.json import (
44+
PolymorphicSerializable,
45+
RegisteredPolymorphicSerializable,
4446
Serializable,
4547
)
4648
from azul.types import (
@@ -483,6 +485,10 @@ def _metadata[V](self, key: str, default: V) -> V:
483485
except KeyError:
484486
return default
485487

488+
@cached_property
489+
def discriminator(self) -> str | None:
490+
return self._metadata('discriminator', None)
491+
486492
def handle(self, x: str) -> T:
487493
if self.custom is None:
488494
return self._handle(x, self._reify(self.field.type))
@@ -536,7 +542,12 @@ def _handle(self, x: str, field_type: Any):
536542
elif issubclass(field_type, Serializable):
537543
inner_cls_name = field_type.__name__
538544
self.globals[inner_cls_name] = field_type
539-
return self._serializable(x, inner_cls_name)
545+
is_polymorphic = issubclass(field_type, PolymorphicSerializable)
546+
has_discriminator = self.discriminator is not None
547+
if is_polymorphic and has_discriminator:
548+
return self._polymorphic(x, inner_cls_name)
549+
else:
550+
return self._serializable(x, inner_cls_name)
540551
else:
541552
origin = get_origin(field_type)
542553
if origin in (Union, UnionType):
@@ -575,6 +586,10 @@ def _optional(self, x: str, field_type: type) -> T:
575586
def _serializable(self, x: str, inner_cls_name: str) -> T:
576587
raise NotImplementedError
577588

589+
@abstractmethod
590+
def _polymorphic(self, x: str, inner_cls_name: str) -> T:
591+
raise NotImplementedError
592+
578593
@abstractmethod
579594
def _list(self, x: str, item_type: type) -> T:
580595
raise NotImplementedError
@@ -603,6 +618,15 @@ def _serializable(self, x: str, inner_cls_name: str) -> Source:
603618
f'{x} = {inner_cls_name}.from_json({x})'
604619
]
605620

621+
def _polymorphic(self, x: str, inner_cls_name: str) -> Source:
622+
depth = next(self.depth)
623+
cls = f'cls{depth}'
624+
return [
625+
f'{cls} = {x}["{self.discriminator}"]',
626+
f'{cls} = {inner_cls_name}.cls_from_json({cls})',
627+
f'{x} = {cls}.from_json({x})'
628+
]
629+
606630
def _primitive(self, x: str, field_type: type) -> Source:
607631
return [
608632
f'if not isinstance({x}, {field_type.__name__}):', [
@@ -670,6 +694,9 @@ def _optional(self, x: str, field_type: type) -> str:
670694
def _serializable(self, x: str, inner_cls_name: str) -> str:
671695
return f'{x}.to_json()'
672696

697+
def _polymorphic(self, x: str, inner_cls_name: str) -> str:
698+
return f'dict({x}.to_json(), {self.discriminator}={x}.cls_to_json())'
699+
673700
def _list(self, x: str, item_type: type) -> str:
674701
depth = next(self.depth)
675702
v = f'v{depth}'
@@ -740,3 +767,93 @@ def _set_field_metadata[T: attrs.Attribute](field: T | None, key, value):
740767
metadata = field.metadata.setdefault('azul', {})
741768
metadata[key] = value
742769
return field
770+
771+
772+
def polymorphic[T: attrs.Attribute](field: T | None = None,
                                    *,
                                    discriminator: str
                                    ) -> T:
    """
    Mark an attrs field to use the given name for the discriminator property in
    serialized instances of PolymorphicSerializable that occur in the value of
    that field. The given discriminator property of a serialized instance
    represents the type to use when deserializing that instance again.

    The ``field`` argument is the attrs field to mark. It is optional and, as
    in the examples below, typically omitted. NOTE(review): a fresh field is
    presumably created in that case, confirm against ``_set_field_metadata``.
    The ``discriminator`` argument is keyword-only and required. The return
    value is whatever ``_set_field_metadata`` returns for the given field.

    >>> class Inner(SerializableAttrs, RegisteredPolymorphicSerializable):
    ...     pass

    >>> @attrs.frozen
    ... class InnerWithInt(Inner):
    ...     x: int

    >>> @attrs.frozen
    ... class InnerWithStr(Inner):
    ...     y: str

    >>> @attrs.frozen(kw_only=True)
    ... class Outer(SerializableAttrs):
    ...     inner: Inner = polymorphic(discriminator='type')
    ...     inners: list[Inner] = polymorphic(discriminator='_cls')

    >>> from azul.doctests import assert_json

    >>> outer = Outer(inner=InnerWithInt(42),
    ...               inners=[InnerWithStr('foo'), InnerWithInt(7)])
    >>> assert_json(outer.to_json())
    {
        "inner": {
            "x": 42,
            "type": "InnerWithInt"
        },
        "inners": [
            {
                "y": "foo",
                "_cls": "InnerWithStr"
            },
            {
                "x": 7,
                "_cls": "InnerWithInt"
            }
        ]
    }
    >>> Outer.from_json(outer.to_json()) == outer
    True

    In order to enable polymorphic serialization of the value of a given field,
    the discriminator property needs to be specified explicitly, otherwise the
    serialization framework will resort to the static type of the field.

    >>> @attrs.frozen
    ... class GenericOuter[T: Inner](SerializableAttrs):
    ...     inner: T

    >>> class StaticOuter(GenericOuter[InnerWithInt]):
    ...     pass

    >>> outer = StaticOuter(InnerWithInt(42))
    >>> outer.to_json()
    {'inner': {'x': 42}}

    Despite the fact that ``{'x': 42}`` does not encode any type information,
    ``from_json`` can tell from the static type of the field that
    ``{'x': 42}`` should be deserialized as an ``InnerWithInt``.

    >>> StaticOuter.from_json(outer.to_json()).inner
    InnerWithInt(x=42)

    >>> StaticOuter.from_json(outer.to_json()) == outer
    True

    However, when the static type of the field is not concrete, deserialization
    may fail or, like in this case, lose information by creating an instance of
    the parent class instead of the class that was serialized.

    >>> @attrs.frozen
    ... class AbstractOuter(SerializableAttrs):
    ...     inner: Inner

    >>> outer = AbstractOuter(InnerWithInt(42))
    >>> AbstractOuter.from_json(outer.to_json()).inner # doctest: +ELLIPSIS
    <azul.attrs.Inner object at ...>
    """
    return _set_field_metadata(field, 'discriminator', discriminator)

src/azul/json.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,66 @@ def to_json(self) -> AnyJSON:
267267
raise NotImplementedError
268268

269269

270+
class PolymorphicSerializable(Serializable):
    """
    A serializable class for which the concrete type of an instance of any of
    its subclasses survives the transformation to JSON and back. To that end,
    the class itself can be transformed to and from JSON, separately from its
    instances.
    """

    @classmethod
    def cls_to_json(cls) -> AnyJSON:
        """
        Return the JSON representation of this class.
        """
        raise NotImplementedError

    @classmethod
    def cls_from_json(cls, json: AnyJSON) -> type[Self]:
        """
        Return the subclass of this class that is represented by the given
        JSON.
        """
        raise NotImplementedError
288+
289+
290+
class RegisteredPolymorphicSerializable(PolymorphicSerializable):
    """
    A polymorphically serializable class that tracks its subclasses in a
    registry and uses their name to discriminate serialized instances. It
    requires every subclass to be registered before instances of that subclass
    can be (de)serialized. It also requires the name of each subclass to be
    unique, regardless of the module the subclass is defined in.
    """

    # Maps the name of every registered subclass to that subclass. The
    # dictionary is created once, on this class, and only ever mutated in
    # place, so all subclasses share it unless one of them overrides it.
    _registry: dict[str, type[Self]] = {}

    @classmethod
    def cls_to_json(cls) -> AnyJSON:
        """
        Return the name of this class, after checking that the registry entry
        for that name is this class.
        """
        assert cls._registry[cls.__name__] == cls
        return cls.__name__

    @classmethod
    def cls_from_json(cls, json: AnyJSON) -> type[Self]:
        """
        Return the class registered under the name given by the JSON string.
        Raises a KeyError if no class is registered under that name.
        """
        return cls._registry[json_str(json)]

    def __init_subclass__(cls, **kwargs):
        """
        Register the newly defined subclass under its name.
        """
        # Accept and forward class keyword arguments so that this class
        # cooperates with other classes in the MRO that define this hook.
        super().__init_subclass__(**kwargs)
        try:
            other_cls = cls._registry[cls.__name__]
        except KeyError:
            pass
        else:
            # For attrs classes, this hook is invoked twice: once for the
            # original class and once for the attrs-generated replacement.
            # These are two different objects, so they are neither the same
            # nor equal, so it is difficult to tell whether we're dealing
            # with the attrs replacement or a genuine collision. Both original
            # and replacement reference the same containing module, so we
            # assume that two classes of the same name from the same module
            # indicate that attrs is involved and do not constitute a
            # collision.
            assert other_cls.__module__ == cls.__module__, R(
                'Class name collision', cls, other_cls)
        cls._registry[cls.__name__] = cls
328+
329+
270330
class Parseable(Serializable):
271331
"""
272332
A class whose instances have a string representation that can be used in

0 commit comments

Comments
 (0)