Skip to content

Commit 2fa130b

Browse files
committed
Improve handling of custom field (de)serializer
1 parent 69556c9 commit 2fa130b

File tree

1 file changed

+68
-41
lines changed

1 file changed

+68
-41
lines changed

src/azul/attrs.py

Lines changed: 68 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
Any,
1414
Callable,
1515
Iterator,
16-
Literal,
1716
Optional,
1817
Self,
1918
Tuple,
@@ -37,6 +36,7 @@
3736

3837
from azul import (
3938
R,
39+
cached_property,
4040
config,
4141
require,
4242
)
@@ -177,6 +177,9 @@ def validator(_instance, field, value):
177177

178178
type Source = list[str | tuple[str, ...] | Source]
179179

180+
type FromJSON = Callable[[AnyJSON], Any]
181+
type ToJSON = Callable[[Any], AnyJSON]
182+
180183

181184
class SerializableAttrs(Serializable, attrs.AttrsInstance):
182185
"""
@@ -293,10 +296,6 @@ def _assert_concrete(cls):
293296
assert not cls._deferred_fields, R(
294297
'Class has fields of unknown type', cls._deferred_fields)
295298

296-
class Metadata(TypedDict):
297-
from_json: Callable[[AnyJSON], Any] | None
298-
to_json: Callable[[Any], AnyJSON] | None
299-
300299
def __init_subclass__(cls):
301300
super().__init_subclass__()
302301
try:
@@ -365,31 +364,22 @@ def _make(cls, fields: list[attrs.Attribute]) -> frozenset[str]:
365364
cls._define(to_json)
366365
return deferred_fields
367366

368-
@classmethod
369-
def _serializable(cls,
370-
field: attrs.Attribute,
371-
key: Literal['from_json', 'to_json']
372-
) -> bool:
373-
try:
374-
return field.metadata['azul'][key] is not None
375-
except KeyError:
376-
return True
377-
378367
@classmethod
379368
def _make_from_json(cls, fields: list[attrs.Attribute]) -> Callable:
380369
globals = {cls.__name__: cls}
370+
deserializers = (cls.Deserializer(cls, field, globals) for field in fields)
381371
source = cls._indent([
382372
'@classmethod',
383373
'def _from_json(cls, json):', [
384374
f'kwargs = super({cls.__name__}, cls)._from_json(json)',
385375
*flatten(
386376
[
387-
f'x = json["{field.name}"]',
388-
*(cls.Deserializer(cls, field, globals).handle('x')),
389-
f'kwargs["{field.name}"] = x'
377+
f'x = json["{deserializer.field.name}"]',
378+
*(deserializer.handle('x')),
379+
f'kwargs["{deserializer.field.name}"] = x'
390380
]
391-
for field in fields
392-
if cls._serializable(field, 'from_json')
381+
for deserializer in deserializers
382+
if deserializer.enabled
393383
),
394384
'return kwargs'
395385
]
@@ -399,6 +389,7 @@ def _make_from_json(cls, fields: list[attrs.Attribute]) -> Callable:
399389
@classmethod
400390
def _make_to_json(cls, fields: list[attrs.Attribute]) -> Callable:
401391
globals = {cls.__name__: cls}
392+
serializers = (cls.Serializer(cls, field, globals) for field in fields)
402393
to_json = cls._indent([
403394
'def to_json(self):', [
404395
# Using the super() shortcut would require messing with the
@@ -407,11 +398,11 @@ def _make_to_json(cls, fields: list[attrs.Attribute]) -> Callable:
407398
f'json = super({cls.__name__}, self).to_json()',
408399
*flatten(
409400
[
410-
f'x = self.{field.name}',
411-
f'json["{field.name}"] = ' + cls.Serializer(cls, field, globals).handle('x')
401+
f'x = self.{serializer.field.name}',
402+
f'json["{serializer.field.name}"] = ' + serializer.handle('x')
412403
]
413-
for field in fields
414-
if cls._serializable(field, 'to_json')
404+
for serializer in serializers
405+
if serializer.enabled
415406
),
416407
'return json'
417408
]
@@ -478,13 +469,25 @@ class Strategy[T](metaclass=ABCMeta):
478469
class MustDefer(Exception):
479470
pass
480471

481-
def handle(self, x: str) -> T:
472+
class Custom(TypedDict):
473+
from_json: FromJSON | None
474+
to_json: ToJSON | None
475+
476+
@cached_property
477+
def custom(self) -> Custom | None:
478+
return self._metadata('custom', None)
479+
480+
def _metadata[V](self, key: str, default: V) -> V:
482481
try:
483-
metadata = self.field.metadata['azul']
482+
return self.field.metadata['azul'][key]
484483
except KeyError:
484+
return default
485+
486+
def handle(self, x: str) -> T:
487+
if self.custom is None:
485488
return self._handle(x, self._reify(self.field.type))
486489
else:
487-
return self._custom(x, metadata)
490+
return self._custom(x)
488491

489492
def _owner(self) -> type:
490493
"""
@@ -551,6 +554,11 @@ def _handle(self, x: str, field_type: Any):
551554
return self._dict(x, key_type, value_type)
552555
raise TypeError('Unserializable field', field_type, self.field)
553556

557+
@property
558+
@abstractmethod
559+
def enabled(self) -> bool:
560+
raise NotImplementedError
561+
554562
@abstractmethod
555563
def _primitive(self, x: str, field_type: type) -> T:
556564
raise NotImplementedError
@@ -576,11 +584,15 @@ def _dict(self, x: str, key_type: type, value_type: type) -> T:
576584
raise NotImplementedError
577585

578586
@abstractmethod
579-
def _custom(self, x: str, metadata: 'SerializableAttrs.Metadata') -> T:
587+
def _custom(self, x: str) -> T:
580588
raise NotImplementedError
581589

582590
class Deserializer(Strategy[Source]):
583591

592+
@property
593+
def enabled(self) -> bool:
594+
return self.custom is None or self.custom['from_json'] is not None
595+
584596
def _optional(self, x: str, field_type: type) -> Source:
585597
return [
586598
f'if {x} is not None:', self._handle(x, field_type)
@@ -632,15 +644,20 @@ def _dict(self, x: str, key_type: type, value_type: type) -> Source:
632644
f'{x} = {d}'
633645
]
634646

635-
def _custom(self, x: str, metadata: 'SerializableAttrs.Metadata') -> Source:
647+
def _custom(self, x: str) -> Source:
636648
var_name = self.field.name + '_from_json'
637-
self.globals[var_name] = not_none(metadata['from_json'])
649+
from_json = not_none(not_none(self.custom)['from_json'])
650+
self.globals[var_name] = from_json
638651
return [
639652
f'{x} = {var_name}({x})'
640653
]
641654

642655
class Serializer(Strategy[str]):
643656

657+
@property
658+
def enabled(self) -> bool:
659+
return self.custom is None or self.custom['to_json'] is not None
660+
644661
def _primitive(self, x: str, field_type: type) -> str:
645662
return x
646663

@@ -665,32 +682,34 @@ def _dict(self, x: str, key_type: type, value_type: type) -> str:
665682
k_, v_ = self._handle(k, key_type), self._handle(v, value_type)
666683
return f'{{{k_}: {v_} for {k}, {v} in x.items()}}'
667684

668-
def _custom(self, x: str, metadata: 'SerializableAttrs.Metadata') -> str:
685+
def _custom(self, x: str) -> str:
686+
to_json = not_none(not_none(self.custom)['to_json'])
669687
var_name = self.field.name + '_to_json'
670-
self.globals[var_name] = not_none(metadata['to_json'])
688+
self.globals[var_name] = to_json
671689
return f'{var_name}({x})'
672690

673691

674-
def serializable[T: attrs.Attribute](field: T,
675-
from_json: Callable[[AnyJSON], Any],
676-
to_json: Callable[[Any], AnyJSON]) -> T:
692+
def serializable[T: attrs.Attribute](field: T | None = None,
693+
*,
694+
from_json: FromJSON,
695+
to_json: ToJSON) -> T:
677696
"""
678697
Use the provided callables to (de)serialize values of the given field,
679698
instead of generating them.
680699
681700
>>> @attrs.frozen
682701
... class Foo(SerializableAttrs):
683-
... x: set[str] = serializable(attrs.field(), to_json=sorted, from_json=set)
702+
... x: set[str] = serializable(to_json=sorted, from_json=set)
684703
685704
>>> Foo(x={'b','a'}).to_json()
686705
{'x': ['a', 'b']}
687706
688707
>>> Foo.from_json({'x': ['a']})
689708
Foo(x={'a'})
690709
"""
691-
field.metadata['azul'] = SerializableAttrs.Metadata(from_json=from_json,
692-
to_json=to_json)
693-
return field
710+
custom = SerializableAttrs.Strategy.Custom(from_json=from_json,
711+
to_json=to_json)
712+
return _set_field_metadata(field, 'custom', custom)
694713

695714

696715
def not_serializable[T: attrs.Attribute](field: T) -> T:
@@ -710,6 +729,14 @@ def not_serializable[T: attrs.Attribute](field: T) -> T:
710729
>>> Foo.from_json({})
711730
Foo(x=42)
712731
"""
713-
field.metadata['azul'] = SerializableAttrs.Metadata(from_json=None,
714-
to_json=None)
732+
custom = SerializableAttrs.Strategy.Custom(from_json=None,
733+
to_json=None)
734+
return _set_field_metadata(field, 'custom', custom)
735+
736+
737+
def _set_field_metadata[T: attrs.Attribute](field: T | None, key, value):
738+
if field is None:
739+
field = attrs.field()
740+
metadata = field.metadata.setdefault('azul', {})
741+
metadata[key] = value
715742
return field

0 commit comments

Comments
 (0)