|
41 | 41 | require,
|
42 | 42 | )
|
43 | 43 | from azul.json import (
|
| 44 | + PolymorphicSerializable, |
| 45 | + RegisteredPolymorphicSerializable, |
44 | 46 | Serializable,
|
45 | 47 | )
|
46 | 48 | from azul.types import (
|
@@ -483,6 +485,10 @@ def _metadata[V](self, key: str, default: V) -> V:
|
483 | 485 | except KeyError:
|
484 | 486 | return default
|
485 | 487 |
|
| 488 | + @cached_property |
| 489 | + def discriminator(self) -> str | None: |
| 490 | + return self._metadata('discriminator', None) |
| 491 | + |
486 | 492 | def handle(self, x: str) -> T:
|
487 | 493 | if self.custom is None:
|
488 | 494 | return self._handle(x, self._reify(self.field.type))
|
@@ -536,7 +542,12 @@ def _handle(self, x: str, field_type: Any):
|
536 | 542 | elif issubclass(field_type, Serializable):
|
537 | 543 | inner_cls_name = field_type.__name__
|
538 | 544 | self.globals[inner_cls_name] = field_type
|
539 |
| - return self._serializable(x, inner_cls_name) |
| 545 | + is_polymorphic = issubclass(field_type, PolymorphicSerializable) |
| 546 | + has_discriminator = self.discriminator is not None |
| 547 | + if is_polymorphic and has_discriminator: |
| 548 | + return self._polymorphic(x, inner_cls_name) |
| 549 | + else: |
| 550 | + return self._serializable(x, inner_cls_name) |
540 | 551 | else:
|
541 | 552 | origin = get_origin(field_type)
|
542 | 553 | if origin in (Union, UnionType):
|
@@ -575,6 +586,10 @@ def _optional(self, x: str, field_type: type) -> T:
|
575 | 586 | def _serializable(self, x: str, inner_cls_name: str) -> T:
|
576 | 587 | raise NotImplementedError
|
577 | 588 |
|
| 589 | + @abstractmethod |
| 590 | + def _polymorphic(self, x: str, inner_cls_name: str) -> T: |
| 591 | + raise NotImplementedError |
| 592 | + |
578 | 593 | @abstractmethod
|
579 | 594 | def _list(self, x: str, item_type: type) -> T:
|
580 | 595 | raise NotImplementedError
|
@@ -603,6 +618,15 @@ def _serializable(self, x: str, inner_cls_name: str) -> Source:
|
603 | 618 | f'{x} = {inner_cls_name}.from_json({x})'
|
604 | 619 | ]
|
605 | 620 |
|
| 621 | + def _polymorphic(self, x: str, inner_cls_name: str) -> Source: |
| 622 | + depth = next(self.depth) |
| 623 | + cls = f'cls{depth}' |
| 624 | + return [ |
| 625 | + f'{cls} = {x}["{self.discriminator}"]', |
| 626 | + f'{cls} = {inner_cls_name}.cls_from_json({cls})', |
| 627 | + f'{x} = {cls}.from_json({x})' |
| 628 | + ] |
| 629 | + |
606 | 630 | def _primitive(self, x: str, field_type: type) -> Source:
|
607 | 631 | return [
|
608 | 632 | f'if not isinstance({x}, {field_type.__name__}):', [
|
@@ -670,6 +694,9 @@ def _optional(self, x: str, field_type: type) -> str:
|
670 | 694 | def _serializable(self, x: str, inner_cls_name: str) -> str:
|
671 | 695 | return f'{x}.to_json()'
|
672 | 696 |
|
| 697 | + def _polymorphic(self, x: str, inner_cls_name: str) -> str: |
| 698 | + return f'dict({x}.to_json(), {self.discriminator}={x}.cls_to_json())' |
| 699 | + |
673 | 700 | def _list(self, x: str, item_type: type) -> str:
|
674 | 701 | depth = next(self.depth)
|
675 | 702 | v = f'v{depth}'
|
@@ -740,3 +767,93 @@ def _set_field_metadata[T: attrs.Attribute](field: T | None, key, value):
|
740 | 767 | metadata = field.metadata.setdefault('azul', {})
|
741 | 768 | metadata[key] = value
|
742 | 769 | return field
|
| 770 | + |
| 771 | + |
| 772 | +def polymorphic[T: attrs.Attribute](field: T | None = None, |
| 773 | + *, |
| 774 | + discriminator: str |
| 775 | + ) -> T: |
| 776 | + """ |
| 777 | + Mark an attrs field to use the given name for the discriminator property in |
| 778 | + serialized instances of PolymorphicSerializable that occur in the value of |
| 779 | + that field. The given discriminator property of a serialized instance |
| 780 | + represents the type to use when deserializing that instance again. |
| 781 | +
|
| 782 | + >>> class Inner(SerializableAttrs, RegisteredPolymorphicSerializable): |
| 783 | + ... pass |
| 784 | +
|
| 785 | + >>> @attrs.frozen |
| 786 | + ... class InnerWithInt(Inner): |
| 787 | + ... x: int |
| 788 | +
|
| 789 | + >>> @attrs.frozen |
| 790 | + ... class InnerWithStr(Inner): |
| 791 | + ... y: str |
| 792 | +
|
| 793 | + >>> @attrs.frozen(kw_only=True) |
| 794 | + ... class Outer(SerializableAttrs): |
| 795 | + ... inner: Inner = polymorphic(discriminator='type') |
| 796 | + ... inners: list[Inner] = polymorphic(discriminator='_cls') |
| 797 | +
|
| 798 | + >>> from azul.doctests import assert_json |
| 799 | +
|
| 800 | + >>> outer = Outer(inner=InnerWithInt(42), |
| 801 | + ... inners=[InnerWithStr('foo'), InnerWithInt(7)]) |
| 802 | + >>> assert_json(outer.to_json()) |
| 803 | + { |
| 804 | + "inner": { |
| 805 | + "x": 42, |
| 806 | + "type": "InnerWithInt" |
| 807 | + }, |
| 808 | + "inners": [ |
| 809 | + { |
| 810 | + "y": "foo", |
| 811 | + "_cls": "InnerWithStr" |
| 812 | + }, |
| 813 | + { |
| 814 | + "x": 7, |
| 815 | + "_cls": "InnerWithInt" |
| 816 | + } |
| 817 | + ] |
| 818 | + } |
| 819 | + >>> Outer.from_json(outer.to_json()) == outer |
| 820 | + True |
| 821 | +
|
| 822 | + In order to enable polymorphic serialization of the value of a given field, |
| 823 | + the discriminator property needs to be specified explicitly, otherwise the |
| 824 | + serialization framework will resort to the static type of the field. |
| 825 | +
|
| 826 | + >>> @attrs.frozen |
| 827 | + ... class GenericOuter[T: Inner](SerializableAttrs): |
| 828 | + ... inner: T |
| 829 | +
|
| 830 | + >>> class StaticOuter(GenericOuter[InnerWithInt]): |
| 831 | + ... pass |
| 832 | +
|
| 833 | + >>> outer = StaticOuter(InnerWithInt(42)) |
| 834 | + >>> outer.to_json() |
| 835 | + {'inner': {'x': 42}} |
| 836 | +
|
| 837 | + Despite the fact that ``{'x': 42}`` does not encode any type information, |
| 838 | + ``from_json`` can tell from the static type of the field that {'x': 42} |
| 839 | + should be deserialized as an ``InnerWithInt``. |
| 840 | +
|
| 841 | + >>> StaticOuter.from_json(outer.to_json()).inner |
| 842 | + InnerWithInt(x=42) |
| 843 | +
|
| 844 | + >>> StaticOuter.from_json(outer.to_json()) == outer |
| 845 | + True |
| 846 | +
|
| 847 | + However, when the static type of the field is not concrete, deserialization |
| 848 | + may fail or, like in this case, lose information by creating an instance of |
| 849 | + the parent class instead of the class that was serialized. |
| 850 | +
|
| 851 | + >>> @attrs.frozen |
| 852 | + ... class AbstractOuter(SerializableAttrs): |
| 853 | + ... inner: Inner |
| 854 | +
|
| 855 | + >>> outer = AbstractOuter(InnerWithInt(42)) |
| 856 | + >>> AbstractOuter.from_json(outer.to_json()).inner # doctest: +ELLIPSIS |
| 857 | + <azul.attrs.Inner object at ...> |
| 858 | + """ |
| 859 | + return _set_field_metadata(field, 'discriminator', discriminator) |
0 commit comments