Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 8aa4469

Browse files
authored Jul 12, 2023
Merge branch 'main' into ck-020
2 parents ba9c0dc + b12b432 commit 8aa4469

File tree

5 files changed

+204
-23
lines changed

5 files changed

+204
-23
lines changed
 

‎.github/workflows/spellcheck.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ jobs:
88
- name: Checkout
99
uses: actions/checkout@v3
1010
- name: Check Spelling
11-
uses: rojopolis/spellcheck-github-actions@0.29.0
11+
uses: rojopolis/spellcheck-github-actions@0.33.0
1212
with:
1313
config_path: .github/spellcheck-settings.yml
1414
task_name: Markdown

‎aredis_om/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
FindQuery,
99
HashModel,
1010
JsonModel,
11+
VectorFieldOptions,
12+
KNNExpression,
1113
NotFoundError,
1214
QueryNotSupportedError,
1315
QuerySyntaxError,

‎aredis_om/model/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
Field,
55
HashModel,
66
JsonModel,
7+
VectorFieldOptions,
8+
KNNExpression,
79
NotFoundError,
810
RedisModel,
911
)

‎aredis_om/model/model.py

Lines changed: 199 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747

4848
model_registry = {}
4949
_T = TypeVar("_T")
50+
Model = TypeVar("Model", bound="RedisModel")
5051
log = logging.getLogger(__name__)
5152
escaper = TokenEscaper()
5253

@@ -252,6 +253,24 @@ def tree(self):
252253
return render_tree(self)
253254

254255

256+
@dataclasses.dataclass
class KNNExpression:
    """Parameters of a K-nearest-neighbours clause attached to a search query.

    ``str()`` of an instance yields the parameterised ``KNN`` clause, while
    ``query_params`` supplies the values bound to its ``$K`` and
    ``$knn_ref_vector`` placeholders.
    """

    # Number of neighbours requested; sent as the ``K`` query parameter.
    k: int
    # Model field holding the vectors; only its ``name`` is read here.
    vector_field: ModelField
    # Raw bytes of the vector that candidates are compared against.
    reference_vector: bytes

    def __str__(self):
        """Return the ``KNN`` clause with placeholders instead of values."""
        field_name = self.vector_field.name
        return f"KNN $K @{field_name} $knn_ref_vector"

    @property
    def query_params(self) -> Dict[str, Union[str, bytes]]:
        """Map each placeholder name used by ``__str__`` to its value."""
        bound: Dict[str, Union[str, bytes]] = {}
        bound["K"] = str(self.k)
        bound["knn_ref_vector"] = self.reference_vector
        return bound

    @property
    def score_field(self) -> str:
        """Return ``__<vector field name>_score``.

        NOTE(review): presumably the name under which the search engine
        reports the distance for this clause -- confirm against the
        RediSearch documentation.
        """
        field_name = self.vector_field.name
        return f"__{field_name}_score"
272+
273+
255274
ExpressionOrNegated = Union[Expression, NegatedExpression]
256275

257276

@@ -349,8 +368,9 @@ def __init__(
349368
self,
350369
expressions: Sequence[ExpressionOrNegated],
351370
model: Type["RedisModel"],
371+
knn: Optional[KNNExpression] = None,
352372
offset: int = 0,
353-
limit: int = DEFAULT_PAGE_SIZE,
373+
limit: Optional[int] = None,
354374
page_size: int = DEFAULT_PAGE_SIZE,
355375
sort_fields: Optional[List[str]] = None,
356376
nocontent: bool = False,
@@ -364,13 +384,16 @@ def __init__(
364384

365385
self.expressions = expressions
366386
self.model = model
387+
self.knn = knn
367388
self.offset = offset
368-
self.limit = limit
389+
self.limit = limit or (self.knn.k if self.knn else DEFAULT_PAGE_SIZE)
369390
self.page_size = page_size
370391
self.nocontent = nocontent
371392

372393
if sort_fields:
373394
self.sort_fields = self.validate_sort_fields(sort_fields)
395+
elif self.knn:
396+
self.sort_fields = [self.knn.score_field]
374397
else:
375398
self.sort_fields = []
376399

@@ -425,11 +448,26 @@ def query(self):
425448
if self._query:
426449
return self._query
427450
self._query = self.resolve_redisearch_query(self.expression)
451+
if self.knn:
452+
self._query = (
453+
self._query
454+
if self._query.startswith("(") or self._query == "*"
455+
else f"({self._query})"
456+
) + f"=>[{self.knn}]"
428457
return self._query
429458

459+
@property
def query_params(self):
    """Return the query parameters as a flat ``[name, value, ...]`` list.

    The list is empty when no KNN expression is attached; otherwise each
    entry of ``self.knn.query_params`` contributes its key followed by its
    value, in the mapping's iteration order.
    """
    flattened: List[Union[str, bytes]] = []
    if not self.knn:
        return flattened
    for param_name, param_value in self.knn.query_params.items():
        flattened.append(param_name)
        flattened.append(param_value)
    return flattened
465+
430466
def validate_sort_fields(self, sort_fields: List[str]):
431467
for sort_field in sort_fields:
432468
field_name = sort_field.lstrip("-")
469+
if self.knn and field_name == self.knn.score_field:
470+
continue
433471
if field_name not in self.model.__fields__:
434472
raise QueryNotSupportedError(
435473
f"You tried sort by {field_name}, but that field "
@@ -728,10 +766,27 @@ def resolve_redisearch_query(cls, expression: ExpressionOrNegated) -> str:
728766
return result
729767

730768
async def execute(self, exhaust_results=True, return_raw_result=False):
731-
args = ["ft.search", self.model.Meta.index_name, self.query, *self.pagination]
769+
args: List[Union[str, bytes]] = [
770+
"FT.SEARCH",
771+
self.model.Meta.index_name,
772+
self.query,
773+
*self.pagination,
774+
]
732775
if self.sort_fields:
733776
args += self.resolve_redisearch_sort_fields()
734777

778+
if self.query_params:
779+
args += ["PARAMS", str(len(self.query_params))] + self.query_params
780+
781+
if self.knn:
782+
# Ensure DIALECT is at least 2
783+
if "DIALECT" not in args:
784+
args += ["DIALECT", "2"]
785+
else:
786+
i_dialect = args.index("DIALECT") + 1
787+
if int(args[i_dialect]) < 2:
788+
args[i_dialect] = "2"
789+
735790
if self.nocontent:
736791
args.append("NOCONTENT")
737792

@@ -917,11 +972,13 @@ def __init__(self, default: Any = Undefined, **kwargs: Any) -> None:
917972
sortable = kwargs.pop("sortable", Undefined)
918973
index = kwargs.pop("index", Undefined)
919974
full_text_search = kwargs.pop("full_text_search", Undefined)
975+
vector_options = kwargs.pop("vector_options", None)
920976
super().__init__(default=default, **kwargs)
921977
self.primary_key = primary_key
922978
self.sortable = sortable
923979
self.index = index
924980
self.full_text_search = full_text_search
981+
self.vector_options = vector_options
925982

926983

927984
class RelationshipInfo(Representation):
@@ -935,6 +992,94 @@ def __init__(
935992
self.link_model = link_model
936993

937994

995+
@dataclasses.dataclass
class VectorFieldOptions:
    """Options describing how a vector field is indexed.

    An instance is rendered by ``schema`` into a
    ``VECTOR <algorithm> <count> <name> <value> ...`` clause. Use the
    ``flat`` / ``hnsw`` constructors to get an instance that carries only the
    optional parameters belonging to that algorithm.
    """

    class ALGORITHM(Enum):
        FLAT = "FLAT"
        HNSW = "HNSW"

    class TYPE(Enum):
        FLOAT32 = "FLOAT32"
        FLOAT64 = "FLOAT64"

    class DISTANCE_METRIC(Enum):
        L2 = "L2"
        IP = "IP"
        COSINE = "COSINE"

    algorithm: ALGORITHM
    type: TYPE
    dimension: int
    distance_metric: DISTANCE_METRIC

    # Common optional parameters
    initial_cap: Optional[int] = None

    # Optional parameters for FLAT
    block_size: Optional[int] = None

    # Optional parameters for HNSW
    m: Optional[int] = None
    ef_construction: Optional[int] = None
    ef_runtime: Optional[int] = None
    epsilon: Optional[float] = None

    @classmethod
    def flat(
        cls,
        type: TYPE,
        dimension: int,
        distance_metric: DISTANCE_METRIC,
        initial_cap: Optional[int] = None,
        block_size: Optional[int] = None,
    ) -> "VectorFieldOptions":
        """Build options for the ``FLAT`` algorithm.

        Only the parameters applicable to ``FLAT`` are accepted. The result
        is an instance of the class the method is called on, so subclasses
        get their own type back.
        """
        return cls(
            algorithm=cls.ALGORITHM.FLAT,
            type=type,
            dimension=dimension,
            distance_metric=distance_metric,
            initial_cap=initial_cap,
            block_size=block_size,
        )

    @classmethod
    def hnsw(
        cls,
        type: TYPE,
        dimension: int,
        distance_metric: DISTANCE_METRIC,
        initial_cap: Optional[int] = None,
        m: Optional[int] = None,
        ef_construction: Optional[int] = None,
        ef_runtime: Optional[int] = None,
        epsilon: Optional[float] = None,
    ) -> "VectorFieldOptions":
        """Build options for the ``HNSW`` algorithm.

        Only the parameters applicable to ``HNSW`` are accepted. The result
        is an instance of the class the method is called on, so subclasses
        get their own type back.
        """
        return cls(
            algorithm=cls.ALGORITHM.HNSW,
            type=type,
            dimension=dimension,
            distance_metric=distance_metric,
            initial_cap=initial_cap,
            m=m,
            ef_construction=ef_construction,
            ef_runtime=ef_runtime,
            epsilon=epsilon,
        )

    @property
    def schema(self) -> str:
        """Render the options as a ``VECTOR ...`` schema clause.

        Every declared field except ``algorithm`` that is not ``None``
        contributes two tokens: its upper-cased name (``dimension`` is
        emitted as ``DIM``) and its value (enum members by ``name``,
        everything else via ``str``). The number after the algorithm is the
        total count of those tokens.
        """
        attr = []
        # Walk the declared dataclass fields rather than ``vars(self)`` so
        # that stray attributes set on the instance never leak into the
        # generated schema.
        for field in dataclasses.fields(self):
            key = field.name
            value = getattr(self, key)
            if key == "algorithm" or value is None:
                continue
            attr.append("DIM" if key == "dimension" else key.upper())
            attr.append(value.name if isinstance(value, Enum) else str(value))

        return " ".join([f"VECTOR {self.algorithm.name} {len(attr)}"] + attr)
1081+
1082+
9381083
def Field(
9391084
default: Any = Undefined,
9401085
*,
@@ -964,6 +1109,7 @@ def Field(
9641109
sortable: Union[bool, UndefinedType] = Undefined,
9651110
index: Union[bool, UndefinedType] = Undefined,
9661111
full_text_search: Union[bool, UndefinedType] = Undefined,
1112+
vector_options: Optional[VectorFieldOptions] = None,
9671113
schema_extra: Optional[Dict[str, Any]] = None,
9681114
) -> Any:
9691115
current_schema_extra = schema_extra or {}
@@ -991,6 +1137,7 @@ def Field(
9911137
sortable=sortable,
9921138
index=index,
9931139
full_text_search=full_text_search,
1140+
vector_options=vector_options,
9941141
**current_schema_extra,
9951142
)
9961143
field_info._validate()
@@ -1083,6 +1230,10 @@ def __new__(cls, name, bases, attrs, **kwargs): # noqa C901
10831230
new_class._meta.primary_key = PrimaryKey(
10841231
name=field_name, field=field
10851232
)
1233+
if field.field_info.vector_options:
1234+
score_attr = f"_{field_name}_score"
1235+
setattr(new_class, score_attr, None)
1236+
new_class.__annotations__[score_attr] = Union[float, None]
10861237

10871238
if not getattr(new_class._meta, "global_key_prefix", None):
10881239
new_class._meta.global_key_prefix = getattr(
@@ -1160,16 +1311,16 @@ async def delete(
11601311
return await cls._delete(db, cls.make_primary_key(pk))
11611312

11621313
@classmethod
1163-
async def get(cls, pk: Any) -> "RedisModel":
1314+
async def get(cls: Type["Model"], pk: Any) -> "Model":
11641315
raise NotImplementedError
11651316

11661317
async def update(self, **field_values):
11671318
"""Update this model instance with the specified key-value pairs."""
11681319
raise NotImplementedError
11691320

11701321
async def save(
1171-
self, pipeline: Optional[redis.client.Pipeline] = None
1172-
) -> "RedisModel":
1322+
self: "Model", pipeline: Optional[redis.client.Pipeline] = None
1323+
) -> "Model":
11731324
raise NotImplementedError
11741325

11751326
async def expire(
@@ -1216,8 +1367,12 @@ def db(cls):
12161367
return cls._meta.database
12171368

12181369
@classmethod
1219-
def find(cls, *expressions: Union[Any, Expression]) -> FindQuery:
1220-
return FindQuery(expressions=expressions, model=cls)
1370+
def find(
1371+
cls,
1372+
*expressions: Union[Any, Expression],
1373+
knn: Optional[KNNExpression] = None,
1374+
) -> FindQuery:
1375+
return FindQuery(expressions=expressions, knn=knn, model=cls)
12211376

12221377
@classmethod
12231378
def from_redis(cls, res: Any):
@@ -1237,7 +1392,7 @@ def to_string(s):
12371392
for i in range(1, len(res), step):
12381393
if res[i + offset] is None:
12391394
continue
1240-
fields = dict(
1395+
fields: Dict[str, str] = dict(
12411396
zip(
12421397
map(to_string, res[i + offset][::2]),
12431398
map(to_string, res[i + offset][1::2]),
@@ -1247,6 +1402,9 @@ def to_string(s):
12471402
if fields.get("$"):
12481403
json_fields = json.loads(fields.pop("$"))
12491404
doc = cls(**json_fields)
1405+
for k, v in fields.items():
1406+
if k.startswith("__") and k.endswith("_score"):
1407+
setattr(doc, k[1:], float(v))
12501408
else:
12511409
doc = cls(**fields)
12521410

@@ -1266,11 +1424,11 @@ def get_annotations(cls):
12661424

12671425
@classmethod
12681426
async def add(
1269-
cls,
1270-
models: Sequence["RedisModel"],
1427+
cls: Type["Model"],
1428+
models: Sequence["Model"],
12711429
pipeline: Optional[redis.client.Pipeline] = None,
12721430
pipeline_verifier: Callable[..., Any] = verify_pipeline_response,
1273-
) -> Sequence["RedisModel"]:
1431+
) -> Sequence["Model"]:
12741432
db = cls._get_db(pipeline, bulk=True)
12751433

12761434
for model in models:
@@ -1345,8 +1503,8 @@ def __init_subclass__(cls, **kwargs):
13451503
)
13461504

13471505
async def save(
1348-
self, pipeline: Optional[redis.client.Pipeline] = None
1349-
) -> "HashModel":
1506+
self: "Model", pipeline: Optional[redis.client.Pipeline] = None
1507+
) -> "Model":
13501508
self.check()
13511509
db = self._get_db(pipeline)
13521510

@@ -1368,7 +1526,7 @@ async def all_pks(cls): # type: ignore
13681526
)
13691527

13701528
@classmethod
1371-
async def get(cls, pk: Any) -> "HashModel":
1529+
async def get(cls: Type["Model"], pk: Any) -> "Model":
13721530
document = await cls.db().hgetall(cls.make_primary_key(pk))
13731531
if not document:
13741532
raise NotFoundError
@@ -1474,7 +1632,13 @@ def schema_for_type(cls, name, typ: Any, field_info: PydanticFieldInfo):
14741632
embedded_cls = embedded_cls[0]
14751633
schema = cls.schema_for_type(name, embedded_cls, field_info)
14761634
elif any(issubclass(typ, t) for t in NUMERIC_TYPES):
1477-
schema = f"{name} NUMERIC"
1635+
vector_options: Optional[VectorFieldOptions] = getattr(
1636+
field_info, "vector_options", None
1637+
)
1638+
if vector_options:
1639+
schema = f"{name} {vector_options.schema}"
1640+
else:
1641+
schema = f"{name} NUMERIC"
14781642
elif issubclass(typ, str):
14791643
if getattr(field_info, "full_text_search", False) is True:
14801644
schema = (
@@ -1513,8 +1677,8 @@ def __init__(self, *args, **kwargs):
15131677
super().__init__(*args, **kwargs)
15141678

15151679
async def save(
1516-
self, pipeline: Optional[redis.client.Pipeline] = None
1517-
) -> "JsonModel":
1680+
self: "Model", pipeline: Optional[redis.client.Pipeline] = None
1681+
) -> "Model":
15181682
self.check()
15191683
db = self._get_db(pipeline)
15201684

@@ -1559,7 +1723,7 @@ async def update(self, **field_values):
15591723
await self.save()
15601724

15611725
@classmethod
1562-
async def get(cls, pk: Any) -> "JsonModel":
1726+
async def get(cls: Type["Model"], pk: Any) -> "Model":
15631727
document = json.dumps(await cls.db().json().get(cls.make_key(pk)))
15641728
if document == "null":
15651729
raise NotFoundError
@@ -1623,10 +1787,22 @@ def schema_for_type(
16231787
# Not a class, probably a type annotation
16241788
field_is_model = False
16251789

1790+
vector_options: Optional[VectorFieldOptions] = getattr(
1791+
field_info, "vector_options", None
1792+
)
1793+
try:
1794+
is_vector = vector_options and any(
1795+
issubclass(get_args(typ)[0], t) for t in NUMERIC_TYPES
1796+
)
1797+
except IndexError:
1798+
raise RedisModelError(
1799+
f"Vector field '{name}' must be annotated as a container type"
1800+
)
1801+
16261802
# When we encounter a list or model field, we need to descend
16271803
# into the values of the list or the fields of the model to
16281804
# find any values marked as indexed.
1629-
if is_container_type:
1805+
if is_container_type and not is_vector:
16301806
field_type = get_origin(typ)
16311807
embedded_cls = get_args(typ)
16321808
if not embedded_cls:
@@ -1689,7 +1865,9 @@ def schema_for_type(
16891865
)
16901866

16911867
# TODO: GEO field
1692-
if parent_is_container_type or parent_is_model_in_container:
1868+
if is_vector and vector_options:
1869+
schema = f"{path} AS {index_field_name} {vector_options.schema}"
1870+
elif parent_is_container_type or parent_is_model_in_container:
16931871
if typ is not str:
16941872
raise RedisModelError(
16951873
"In this Preview release, list and tuple fields can only "

‎pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ python = ">=3.7,<4.0"
3939
redis = ">=3.5.3,<5.0.0"
4040
pydantic = "^1.10.2"
4141
click = "^8.0.1"
42-
pptree = "^3.1"
4342
types-redis = ">=3.5.9,<5.0.0"
4443
python-ulid = "^1.0.3"
4544
typing-extensions = "^4.4.0"

0 commit comments

Comments
 (0)
Please sign in to comment.