Skip to content

Commit 6531c5c

Browse files
authored
Merge pull request #121 from andruten/expansion-depth-recursive
Expansion depth and recursive expansion
2 parents 6eecc83 + 0d3da72 commit 6531c5c

File tree

6 files changed

+177
-11
lines changed

6 files changed

+177
-11
lines changed

.gitignore

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@ local-dev.txt
55
dist/
66
MANIFEST
77
.mypy_cache/
8+
.idea/
89
.vscode/
910
drf_flex_fields.egg-info/
1011
venv.sh
11-
.venv
12+
.venv
13+
venv/

README.md

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -483,19 +483,30 @@ class PersonSerializer(FlexFieldsModelSerializer):
483483

484484
Parameter names and wildcard values can be configured within a Django setting, named `REST_FLEX_FIELDS`.
485485

486-
| Option | Description | Default |
487-
| --------------- | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | --------------- |
488-
| EXPAND_PARAM | The name of the parameter with the fields to be expanded | `"expand"` |
489-
| FIELDS_PARAM | The name of the parameter with the fields to be included (others will be omitted) | `"fields"` |
490-
| OMIT_PARAM | The name of the parameter with the fields to be omitted | `"omit"` |
491-
| WILDCARD_VALUES | List of values that stand in for all field names. Can be used with the `fields` and `expand` parameters. <br><br>When used with `expand`, a wildcard value will trigger the expansion of all `expandable_fields` at a given level.<br><br>When used with `fields`, all fields are included at a given level. For example, you could pass `fields=name,state.*` if you have a city resource with a nested state in order to expand only the city's name field and all of the state's fields. <br><br>To disable use of wildcards, set this setting to `None`. | `["*", "~all"]` |
486+
| Option | Description | Default |
487+
|-------------------------------|:------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:|-----------------|
488+
| EXPAND_PARAM | The name of the parameter with the fields to be expanded | `"expand"` |
489+
| MAXIMUM_EXPANSION_DEPTH | The number of maximum depth permitted expansion | `None` |
490+
| FIELDS_PARAM | The name of the parameter with the fields to be included (others will be omitted) | `"fields"` |
491+
| OMIT_PARAM | The name of the parameter with the fields to be omitted | `"omit"` |
492+
| RECURSIVE_EXPANSION_PERMITTED | If `False`, an exception is raised when a recursive pattern is found | `True` |
493+
| WILDCARD_VALUES | List of values that stand in for all field names. Can be used with the `fields` and `expand` parameters. <br><br>When used with `expand`, a wildcard value will trigger the expansion of all `expandable_fields` at a given level.<br><br>When used with `fields`, all fields are included at a given level. For example, you could pass `fields=name,state.*` if you have a city resource with a nested state in order to expand only the city's name field and all of the state's fields. <br><br>To disable use of wildcards, set this setting to `None`. | `["*", "~all"]` |
492494

493495
For example, if you want your API to work a bit more like [JSON API](https://jsonapi.org/format/#fetching-includes), you could do:
494496

495497
```python
496498
REST_FLEX_FIELDS = {"EXPAND_PARAM": "include"}
497499
```
498500

501+
### Defining expansion and recursive limits at serializer level
502+
503+
`maximum_expansion_depth` property can be overridden at serializer level. It can be configured as `int` or `None`.
504+
505+
`recursive_expansion_permitted` property can be overridden at serializer level. It must be `bool`.
506+
507+
Both settings raise `serializers.ValidationError` when conditions are met but exceptions can be overridden in `_recursive_expansion_found` and `_expansion_depth_exceeded` methods.
508+
509+
499510
## Serializer Introspection
500511

501512
When using an instance of `FlexFieldsModelSerializer`, you can examine the property `expanded_fields` to discover which fields, if any, have been dynamically expanded.

rest_flex_fields/__init__.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
EXPAND_PARAM = FLEX_FIELDS_OPTIONS.get("EXPAND_PARAM", "expand")
66
FIELDS_PARAM = FLEX_FIELDS_OPTIONS.get("FIELDS_PARAM", "fields")
77
OMIT_PARAM = FLEX_FIELDS_OPTIONS.get("OMIT_PARAM", "omit")
8+
MAXIMUM_EXPANSION_DEPTH = FLEX_FIELDS_OPTIONS.get("MAXIMUM_EXPANSION_DEPTH", None)
9+
RECURSIVE_EXPANSION_PERMITTED = FLEX_FIELDS_OPTIONS.get("RECURSIVE_EXPANSION_PERMITTED", True)
810

911
WILDCARD_ALL = "~all"
1012
WILDCARD_ASTERISK = "*"
@@ -20,9 +22,12 @@
2022
assert isinstance(FIELDS_PARAM, str), "'FIELDS_PARAM' should be a string"
2123
assert isinstance(OMIT_PARAM, str), "'OMIT_PARAM' should be a string"
2224

23-
if type(WILDCARD_VALUES) not in (list, None):
25+
if type(WILDCARD_VALUES) not in (list, type(None)):
2426
raise ValueError("'WILDCARD_EXPAND_VALUES' should be a list of strings or None")
25-
27+
if type(MAXIMUM_EXPANSION_DEPTH) not in (int, type(None)):
28+
raise ValueError("'MAXIMUM_EXPANSION_DEPTH' should be a int or None")
29+
if type(RECURSIVE_EXPANSION_PERMITTED) is not bool:
30+
raise ValueError("'RECURSIVE_EXPANSION_PERMITTED' should be a bool")
2631

2732
from .utils import *
2833
from .serializers import FlexFieldsModelSerializer

rest_flex_fields/serializers.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import importlib
33
from typing import List, Optional, Tuple
44

5+
from django.conf import settings
56
from rest_framework import serializers
67

78
from rest_flex_fields import (
@@ -22,6 +23,8 @@ class FlexFieldsSerializerMixin(object):
2223
"""
2324

2425
expandable_fields = {}
26+
maximum_expansion_depth: Optional[int] = None
27+
recursive_expansion_permitted: Optional[bool] = None
2528

2629
def __init__(self, *args, **kwargs):
2730
expand = list(kwargs.pop(EXPAND_PARAM, []))
@@ -58,6 +61,21 @@ def __init__(self, *args, **kwargs):
5861
+ self._flex_options_rep_only["omit"],
5962
}
6063

64+
def get_maximum_expansion_depth(self) -> Optional[int]:
65+
"""
66+
Defined at serializer level or based on MAXIMUM_EXPANSION_DEPTH setting
67+
"""
68+
return self.maximum_expansion_depth or settings.REST_FLEX_FIELDS.get("MAXIMUM_EXPANSION_DEPTH", None)
69+
70+
def get_recursive_expansion_permitted(self) -> bool:
71+
"""
72+
Defined at serializer level or based on RECURSIVE_EXPANSION_PERMITTED setting
73+
"""
74+
if self.recursive_expansion_permitted is not None:
75+
return self.recursive_expansion_permitted
76+
else:
77+
return settings.REST_FLEX_FIELDS.get("RECURSIVE_EXPANSION_PERMITTED", True)
78+
6179
def to_representation(self, instance):
6280
if not self._flex_fields_rep_applied:
6381
self.apply_flex_fields(self.fields, self._flex_options_rep_only)
@@ -264,11 +282,63 @@ def _get_query_param_value(self, field: str) -> List[str]:
264282
if not values:
265283
values = self.context["request"].query_params.getlist("{}[]".format(field))
266284

285+
for expand_path in values:
286+
self._validate_recursive_expansion(expand_path)
287+
self._validate_expansion_depth(expand_path)
288+
267289
if values and len(values) == 1:
268290
return values[0].split(",")
269291

270292
return values or []
271293

294+
def _split_expand_field(self, expand_path: str) -> List[str]:
295+
return expand_path.split('.')
296+
297+
def recursive_expansion_not_permitted(self):
298+
"""
299+
A customized exception can be raised when recursive expansion is found, default ValidationError
300+
"""
301+
raise serializers.ValidationError(detail="Recursive expansion found")
302+
303+
def _validate_recursive_expansion(self, expand_path: str) -> None:
304+
"""
305+
Given an expand_path, a dotted-separated string,
306+
an Exception is raised when a recursive
307+
expansion is detected.
308+
Only applies when REST_FLEX_FIELDS["RECURSIVE_EXPANSION"] setting is False.
309+
"""
310+
recursive_expansion_permitted = self.get_recursive_expansion_permitted()
311+
if recursive_expansion_permitted is True:
312+
return
313+
314+
expansion_path = self._split_expand_field(expand_path)
315+
expansion_length = len(expansion_path)
316+
expansion_length_unique = len(set(expansion_path))
317+
if expansion_length != expansion_length_unique:
318+
self.recursive_expansion_not_permitted()
319+
320+
def expansion_depth_exceeded(self):
321+
"""
322+
A customized exception can be raised when expansion depth is found, default ValidationError
323+
"""
324+
raise serializers.ValidationError(detail="Expansion depth exceeded")
325+
326+
def _validate_expansion_depth(self, expand_path: str) -> None:
327+
"""
328+
Given an expand_path, a dotted-separated string,
329+
an Exception is raised when expansion level is
330+
greater than the `expansion_depth` property configuration.
331+
Only applies when REST_FLEX_FIELDS["EXPANSION_DEPTH"] setting is set
332+
or serializer has its own expansion configuration through default_expansion_depth attribute.
333+
"""
334+
maximum_expansion_depth = self.get_maximum_expansion_depth()
335+
if maximum_expansion_depth is None:
336+
return
337+
338+
expansion_path = self._split_expand_field(expand_path)
339+
if len(expansion_path) > maximum_expansion_depth:
340+
self.expansion_depth_exceeded()
341+
272342
def _get_permitted_expands_from_query_param(self, expand_param: str) -> List[str]:
273343
"""
274344
If a list of permitted_expands has been passed to context,

tests/test_flex_fields_model_serializer.py

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,17 @@
11
from unittest import TestCase
2+
from unittest.mock import patch, PropertyMock
23

4+
from django.test import override_settings
35
from django.utils.datastructures import MultiValueDict
6+
from rest_framework import serializers
7+
48
from rest_flex_fields import FlexFieldsModelSerializer
59

610

711
class MockRequest(object):
8-
def __init__(self, query_params=MultiValueDict(), method="GET"):
12+
def __init__(self, query_params=None, method="GET"):
13+
if query_params is None:
14+
query_params = MultiValueDict()
915
self.query_params = query_params
1016
self.method = method
1117

@@ -178,3 +184,73 @@ def test_import_serializer_class(self):
178184

179185
def test_make_expanded_field_serializer(self):
180186
pass
187+
188+
@override_settings(REST_FLEX_FIELDS={"RECURSIVE_EXPANSION_PERMITTED": False})
189+
def test_recursive_expansion(self):
190+
with self.assertRaises(serializers.ValidationError):
191+
FlexFieldsModelSerializer(
192+
context={
193+
"request": MockRequest(
194+
method="GET", query_params=MultiValueDict({"expand": ["dog.leg.dog"]})
195+
)
196+
}
197+
)
198+
199+
@patch('rest_flex_fields.FlexFieldsModelSerializer.recursive_expansion_permitted', new_callable=PropertyMock)
200+
def test_recursive_expansion_serializer_level(self, mock_recursive_expansion_permitted):
201+
mock_recursive_expansion_permitted.return_value = False
202+
203+
with self.assertRaises(serializers.ValidationError):
204+
FlexFieldsModelSerializer(
205+
context={
206+
"request": MockRequest(
207+
method="GET", query_params=MultiValueDict({"expand": ["dog.leg.dog"]})
208+
)
209+
}
210+
)
211+
212+
@override_settings(REST_FLEX_FIELDS={"MAXIMUM_EXPANSION_DEPTH": 3})
213+
def test_expansion_depth(self):
214+
serializer = FlexFieldsModelSerializer(
215+
context={
216+
"request": MockRequest(
217+
method="GET", query_params=MultiValueDict({"expand": ["dog.leg.paws"]})
218+
)
219+
}
220+
)
221+
self.assertEqual(serializer._flex_options_all["expand"], ["dog.leg.paws"])
222+
223+
@override_settings(REST_FLEX_FIELDS={"MAXIMUM_EXPANSION_DEPTH": 2})
224+
def test_expansion_depth_exception(self):
225+
with self.assertRaises(serializers.ValidationError):
226+
FlexFieldsModelSerializer(
227+
context={
228+
"request": MockRequest(
229+
method="GET", query_params=MultiValueDict({"expand": ["dog.leg.paws"]})
230+
)
231+
}
232+
)
233+
234+
@patch('rest_flex_fields.FlexFieldsModelSerializer.maximum_expansion_depth', new_callable=PropertyMock)
235+
def test_expansion_depth_serializer_level(self, mock_maximum_expansion_depth):
236+
mock_maximum_expansion_depth.return_value = 3
237+
serializer = FlexFieldsModelSerializer(
238+
context={
239+
"request": MockRequest(
240+
method="GET", query_params=MultiValueDict({"expand": ["dog.leg.paws"]})
241+
)
242+
}
243+
)
244+
self.assertEqual(serializer._flex_options_all["expand"], ["dog.leg.paws"])
245+
246+
@patch('rest_flex_fields.FlexFieldsModelSerializer.maximum_expansion_depth', new_callable=PropertyMock)
247+
def test_expansion_depth_serializer_level_exception(self, mock_maximum_expansion_depth):
248+
mock_maximum_expansion_depth.return_value = 2
249+
with self.assertRaises(serializers.ValidationError):
250+
FlexFieldsModelSerializer(
251+
context={
252+
"request": MockRequest(
253+
method="GET", query_params=MultiValueDict({"expand": ["dog.leg.paws"]})
254+
)
255+
}
256+
)

tests/test_serializer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@
1010

1111

1212
class MockRequest(object):
13-
def __init__(self, query_params={}, method="GET"):
13+
def __init__(self, query_params=None, method="GET"):
14+
if query_params is None:
15+
query_params = {}
1416
self.query_params = query_params
1517
self.method = method
1618

0 commit comments

Comments
 (0)