diff --git a/aredis_om/model/model.py b/aredis_om/model/model.py index ecacda3c..7d7ebb5e 100644 --- a/aredis_om/model/model.py +++ b/aredis_om/model/model.py @@ -14,6 +14,7 @@ ClassVar, Dict, List, + Literal, Mapping, Optional, Sequence, @@ -141,10 +142,10 @@ def embedded(cls): def is_supported_container_type(typ: Optional[type]) -> bool: # TODO: Wait, why don't we support indexing sets? - if typ == list or typ == tuple: + if typ == list or typ == tuple or typ == Literal: return True unwrapped = get_origin(typ) - return unwrapped == list or unwrapped == tuple + return unwrapped == list or unwrapped == tuple or unwrapped == Literal def validate_model_fields(model: Type["RedisModel"], field_values: Dict[str, Any]): @@ -1414,6 +1415,8 @@ def outer_type_or_annotation(field): if not isinstance(field.annotation, type): raise AttributeError(f"could not extract outer type from field {field}") return field.annotation + elif get_origin(field.annotation) == Literal: + return str else: return field.annotation.__args__[0] @@ -2057,21 +2060,33 @@ def schema_for_type( # find any values marked as indexed. if is_container_type and not is_vector: field_type = get_origin(typ) - embedded_cls = get_args(typ) - if not embedded_cls: - log.warning( - "Model %s defined an empty list or tuple field: %s", cls, name + if field_type == Literal: + path = f"{json_path}.{name}" + return cls.schema_for_type( + path, + name, + name_prefix, + str, + field_info, + parent_type=field_type, + ) + else: + embedded_cls = get_args(typ) + if not embedded_cls: + log.warning( + "Model %s defined an empty list or tuple field: %s", cls, name + ) + return "" + path = f"{json_path}.{name}[*]" + embedded_cls = embedded_cls[0] + return cls.schema_for_type( + path, + name, + name_prefix, + embedded_cls, + field_info, + parent_type=field_type, ) - return "" - embedded_cls = embedded_cls[0] - return cls.schema_for_type( - f"{json_path}.{name}[*]", - name, - name_prefix, - embedded_cls, - field_info, - parent_type=field_type, - ) elif field_is_model: name_prefix = f"{name_prefix}_{name}" if name_prefix else name sub_fields = [] diff --git a/tests/test_hash_model.py b/tests/test_hash_model.py index de4bdb89..95339720 100644 --- a/tests/test_hash_model.py +++ b/tests/test_hash_model.py @@ -917,3 +917,25 @@ class TestUpdate(HashModel): rematerialized = await TestUpdate.find(TestUpdate.pk == t.pk).first() assert rematerialized.age == 34 + + +@py_test_mark_asyncio +async def test_literals(): + from typing import Literal + + class TestLiterals(HashModel): + flavor: Literal["apple", "pumpkin"] = Field(index=True, default="apple") + + schema = TestLiterals.redisearch_schema() + + key_prefix = TestLiterals.make_key( + TestLiterals._meta.primary_key_pattern.format(pk="") + ) + assert schema == ( + f"ON HASH PREFIX 1 {key_prefix} SCHEMA pk TAG SEPARATOR | flavor TAG SEPARATOR |" + ) + await Migrator().run() + item = TestLiterals(flavor="pumpkin") + await item.save() + rematerialized = await TestLiterals.find(TestLiterals.flavor == "pumpkin").first() + assert rematerialized.pk == item.pk diff --git a/tests/test_json_model.py b/tests/test_json_model.py index 47cfd2f8..24f0f627 100644 --- a/tests/test_json_model.py +++ b/tests/test_json_model.py @@ -1098,6 +1098,7 @@ class ModelWithIntPk(JsonModel): m = await ModelWithIntPk.find(ModelWithIntPk.my_id == 42).first() assert m.my_id == 42 + @py_test_mark_asyncio async def test_pagination(): class Test(JsonModel): @@ -1121,3 +1122,26 @@ async def get_page(cls, offset, limit): res = await Test.get_page(10, 30) assert len(res) == 30 assert res[0].num == 10 + + +@py_test_mark_asyncio +async def test_literals(): + from typing import Literal + + class TestLiterals(JsonModel): + flavor: Literal["apple", "pumpkin"] = Field(index=True, default="apple") + + schema = TestLiterals.redisearch_schema() + + key_prefix = TestLiterals.make_key( + TestLiterals._meta.primary_key_pattern.format(pk="") + ) + assert schema == ( + f"ON JSON PREFIX 1 {key_prefix} SCHEMA $.pk AS pk TAG SEPARATOR | " + "$.flavor AS flavor TAG SEPARATOR |" + ) + await Migrator().run() + item = TestLiterals(flavor="pumpkin") + await item.save() + rematerialized = await TestLiterals.find(TestLiterals.flavor == "pumpkin").first() + assert rematerialized.pk == item.pk