Skip to content

5766 Add nested field filtering and ensure defer omitted fields #41

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ dist/
*.egg-info/
build/
.tox/
.idea
207 changes: 194 additions & 13 deletions drf_dynamic_fields/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
"""
Mixin to dynamically select only a subset of fields per DRF resource.
"""

import warnings

from django.conf import settings
from django.db.models import Prefetch
from django.utils.functional import cached_property


Expand All @@ -20,11 +22,11 @@ def fields(self):

A blank `fields` parameter (?fields) will remove all fields. Not
passing `fields` will pass all fields individual fields are comma
separated (?fields=id,name,url,email).
separated (?fields=id,name,url,email,teachers__age).

"""
fields = super(DynamicFieldsMixin, self).fields

fields = super(DynamicFieldsMixin, self).fields
if not hasattr(self, "_context"):
# We are being called before a request cycle
return fields
Expand Down Expand Up @@ -58,30 +60,209 @@ def fields(self):
try:
filter_fields = params.get("fields", None).split(",")
except AttributeError:
filter_fields = None
filter_fields = []

try:
omit_fields = params.get("omit", None).split(",")
except AttributeError:
omit_fields = []

# Drop any fields that are not specified in the `fields` argument.
self._flat_allow = set()
self._flat_omit = set()
self._nested_allow = {}
self._nested_omit = {}

# store top-level and nested fields specified in the `fields` argument.
for filtered_field in filter_fields:
if "__" in filtered_field:
parent, child = filtered_field.split("__", 1)
self._nested_allow.setdefault(parent, []).append(child)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could this have been written as a defaultdict?

# If a nested field is allowed the related parent level field
# must also be allowed
self._flat_allow.add(parent)
else:
self._flat_allow.add(filtered_field)

# store top-level and nested fields in the `omit` argument.
for omitted_field in omit_fields:
if "__" in omitted_field:
parent, child = omitted_field.split("__", 1)
self._nested_omit.setdefault(parent, []).append(child)
else:
self._flat_omit.add(omitted_field)

# Drop top-level fields
existing = set(fields.keys())
if filter_fields is None:
# no fields param given, don't filter.
allowed = existing
if "fields" in params:
allowed = self._flat_allow
else:
allowed = set(filter(None, filter_fields))

# omit fields in the `omit` argument.
omitted = set(filter(None, omit_fields))
allowed = existing
omitted = self._flat_omit

for field in existing:

if field not in allowed:
fields.pop(field, None)

if field in omitted:
fields.pop(field, None)

# Drop omitted child fields from nested serializers
for parent, omit_list in self._nested_omit.items():
field = fields[parent]
nested_serializer = getattr(field, "child", field)
if hasattr(nested_serializer, "fields"):
for child in omit_list:
nested_serializer.fields.pop(child, None)

# Drop non-allowed child fields from the nested serializers
for parent, allow_list in self._nested_allow.items():
field = fields[parent]
nested_serializer = getattr(field, "child", field)
if hasattr(nested_serializer, "fields"):
for child_name in list(nested_serializer.fields):
if child_name not in allow_list:
nested_serializer.fields.pop(child_name, None)

return fields

def _get_disallowed_top_level_fields_to_defer(self):
"""
Determine which top-level model fields should be deferred when an explicit
fields filter is in use.
Other model fields not explicitly included in 'fields' are deferred.
"""
allow = getattr(self, "_flat_allow", None)
model = getattr(self.Meta, "model", None)
if not allow or model is None:
return []

# Filter out fields that have a database column associated with them.
field_names = [
field.name
for field in model._meta.get_fields()
if getattr(field, "concrete", False)
]
return [field_name for field_name in field_names if field_name not in allow]

def _get_disallowed_nested_level_fields_to_defer(self):
"""
Determine which nested-model fields should be deferred for each parent serializer
when an explicit fields filter is in use.
Other model nested fields not explicitly included in 'fields' are deferred.
"""
fields_to_defer = []
for parent, allow_list in getattr(self, "_nested_allow", {}).items():
field = self.fields.get(parent)
if not field:
continue

child_serializer = getattr(field, "child", field)
nested_model = getattr(child_serializer.Meta, "model", None)
if nested_model is None:
continue

# Filter out nested fields that have a database column associated
# with them.
field_names = [
field.name
for field in nested_model._meta.get_fields()
if getattr(field, "concrete", False)
]
for field_name in field_names:
if field_name not in allow_list:
fields_to_defer.append(f"{parent}__{field_name}")

return fields_to_defer

def get_deferred_model_fields(self):
"""
Returns a flat list of omitted model-fields; top-level and nested.
Ensures that parsing of "fields"/"omit" has run by accessing ".fields".
"""

# Trigger parsing of required attributes if not already set
if not all(
hasattr(self, attr)
for attr in ("_flat_omit", "_nested_omit", "_flat_allow", "_nested_allow")
):
_ = self.fields

flat_omit = getattr(self, "_flat_omit", [])
nested_omit = getattr(self, "_nested_omit", {})
deferred = []
# Set omit top-level fields to defer
deferred.extend(flat_omit)

# Set omit nested-level fields to defer
deferred.extend(
f"{parent}__{child}"
for parent, children in nested_omit.items()
for child in children
)
# Set disallowed top-level fields to defer
deferred.extend(self._get_disallowed_top_level_fields_to_defer())
# Set disallowed nested-level fields to defer
deferred.extend(self._get_disallowed_nested_level_fields_to_defer())

# Remove any duplicate fields
return list(set(deferred))


class DeferredFieldsMixin:
"""ViewSet Mixin that:
- defers top‐level model columns based on omit/fields
- builds a Prefetch for each nested relation to defer its columns too
"""

@staticmethod
def _split_deferred_fields(fields):
"""Split deferred fields into top‐level fields and nested relations."""
parent = []
nested = {}
for field in fields:
if "__" in field:
parent_field, child_field = field.split("__", 1)
nested.setdefault(parent_field, []).append(child_field)
else:
parent.append(field)
return parent, nested

@staticmethod
def _apply_nested_prefetch(qs, nested_map, serializer):
"""For each nested relation, add a Prefetch that defers its specified
fields.
"""
for parent_field, child_fields in nested_map.items():
field = serializer.fields.get(parent_field)
if not field:
continue

child_serializer = getattr(field, "child", field)
model = getattr(child_serializer.Meta, "model", None)
if not model:
continue

qs = qs.prefetch_related(
Prefetch(parent_field, queryset=model.objects.defer(*child_fields))
)
return qs

def get_queryset(self):
"""
Returns a queryset with top-level and nested fields deferred to
optimize database retrieval.
"""
qs = super().get_queryset()
# instantiate serializer so deferred fields are calculated
serializer = self.get_serializer_class()(context=self.get_serializer_context())

# split deferred fields into top-level and nested
fields = serializer.get_deferred_model_fields()
parent_fields, nested_map = self._split_deferred_fields(fields)

# defer top-level fields
if parent_fields:
qs = qs.defer(*parent_fields)

# defer nested fields via Prefetch
qs = self._apply_nested_prefetch(qs, nested_map, serializer)
return qs
41 changes: 21 additions & 20 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,25 @@
from setuptools import setup

readme = open('README.rst').read()
readme = open("README.rst").read()

setup(name='drf_dynamic_fields',
version='0.4.0',
description='Dynamically return subset of Django REST Framework serializer fields',
author='Danilo Bargen',
author_email='[email protected]',
url='https://github.com/dbrgn/drf-dynamic-fields',
packages=['drf_dynamic_fields'],
zip_safe=True,
include_package_data=True,
license='MIT',
keywords='drf restframework rest_framework django_rest_framework serializers',
long_description=readme,
classifiers=[
'Development Status :: 5 - Production/Stable',
'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 3',
'Framework :: Django',
'Environment :: Web Environment',
],
setup(
name="drf_dynamic_fields",
version="0.4.0",
description="Dynamically return subset of Django REST Framework serializer fields",
author="Danilo Bargen",
author_email="[email protected]",
url="https://github.com/dbrgn/drf-dynamic-fields",
packages=["drf_dynamic_fields"],
zip_safe=True,
include_package_data=True,
license="MIT",
keywords="drf restframework rest_framework django_rest_framework serializers",
long_description=readme,
classifiers=[
"Development Status :: 5 - Production/Stable",
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3",
"Framework :: Django",
"Environment :: Web Environment",
],
)
10 changes: 10 additions & 0 deletions tests/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Some models for the tests. We are modelling a school.
"""

from django.db import models


Expand All @@ -14,3 +15,12 @@ class School(models.Model):

name = models.CharField(max_length=30)
teachers = models.ManyToManyField(Teacher)


class Child(models.Model):
secret = models.CharField(max_length=100)
public = models.CharField(max_length=100)


class Parent(models.Model):
child = models.ForeignKey(Child, on_delete=models.CASCADE)
17 changes: 15 additions & 2 deletions tests/serializers.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
"""
For the tests.
"""

from rest_framework import serializers

from drf_dynamic_fields import DynamicFieldsMixin
from drf_dynamic_fields import DynamicFieldsMixin, DeferredFieldsMixin

from .models import Teacher, School

Expand All @@ -29,7 +30,9 @@ def get_request_info(self, teacher):
return request.build_absolute_uri("/api/v1/teacher/{}".format(teacher.pk))


class SchoolSerializer(DynamicFieldsMixin, serializers.ModelSerializer):
class SchoolSerializer(
DynamicFieldsMixin, DeferredFieldsMixin, serializers.ModelSerializer
):
"""
Interesting enough serializer because the TeacherSerializer
will use ListSerializer due to the `many=True`
Expand All @@ -40,3 +43,13 @@ class SchoolSerializer(DynamicFieldsMixin, serializers.ModelSerializer):
class Meta:
model = School
fields = ("id", "teachers", "name")


class ChildSerializer(DynamicFieldsMixin, serializers.Serializer):
secret = serializers.CharField()
public = serializers.CharField()


class ParentSerializer(DynamicFieldsMixin, serializers.Serializer):
id = serializers.IntegerField()
child = ChildSerializer()
Loading