Skip to content

Commit c732d83

Browse files
authored
Merge pull request #348 from sebbekarlsson/reference-field-callback
Added label_modifier option on ReferenceField conversion
2 parents 60409e2 + 6475cd5 commit c732d83

File tree

3 files changed

+38
-1
lines changed

3 files changed

+38
-1
lines changed

flask_mongoengine/wtf/fields.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,15 @@ def __init__(
4646
label_attr="",
4747
allow_blank=False,
4848
blank_text="---",
49+
label_modifier=None,
4950
**kwargs,
5051
):
5152

5253
super(QuerySetSelectField, self).__init__(label, validators, **kwargs)
5354
self.label_attr = label_attr
5455
self.allow_blank = allow_blank
5556
self.blank_text = blank_text
57+
self.label_modifier = label_modifier
5658
self.queryset = queryset
5759

5860
def iter_choices(self):
@@ -64,7 +66,12 @@ def iter_choices(self):
6466

6567
self.queryset.rewind()
6668
for obj in self.queryset:
67-
label = self.label_attr and getattr(obj, self.label_attr) or obj
69+
label = (
70+
self.label_modifier(obj)
71+
if self.label_modifier
72+
else (self.label_attr and getattr(obj, self.label_attr) or obj)
73+
)
74+
6875
if isinstance(self.data, list):
6976
selected = obj in self.data
7077
else:

flask_mongoengine/wtf/orm.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,11 @@ def convert(self, model, field, field_args):
8989
if hasattr(field, "to_form_field"):
9090
return field.to_form_field(model, kwargs)
9191

92+
if hasattr(field, "field") and type(field.field) == ReferenceField:
93+
kwargs["label_modifier"] = getattr(
94+
model, field.name + "_label_modifier", None
95+
)
96+
9297
if ftype in self.converters:
9398
return self.converters[ftype](model, field, kwargs)
9499

tests/test_forms.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -534,6 +534,31 @@ class Post(db.Document):
534534
form = PostForm()
535535
self.assertTrue("content-text" in "%s" % form.content.text)
536536

537+
def test_form_label_modifier(self):
538+
with self.app.test_request_context("/"):
539+
db = self.db
540+
541+
class FoodItem(db.Document):
542+
title = db.StringField()
543+
544+
class FoodStore(db.Document):
545+
title = db.StringField(max_length=120, required=True)
546+
food_items = db.ListField(db.ReferenceField(FoodItem))
547+
548+
def food_items_label_modifier(obj):
549+
return obj.title
550+
551+
fruit_names = ["banana", "apple", "pear"]
552+
553+
food_items = [FoodItem(title=name).save() for name in fruit_names]
554+
555+
FoodStore(title="John's fruits", food_items=food_items).save()
556+
557+
FoodStoreForm = model_form(FoodStore)
558+
form = FoodStoreForm()
559+
560+
assert [obj.label.text for obj in form.food_items] == fruit_names
561+
537562

538563
if __name__ == "__main__":
539564
unittest.main()

0 commit comments

Comments
 (0)