47
47
48
48
model_registry = {}
49
49
_T = TypeVar ("_T" )
50
+ Model = TypeVar ("Model" , bound = "RedisModel" )
50
51
log = logging .getLogger (__name__ )
51
52
escaper = TokenEscaper ()
52
53
@@ -252,6 +253,24 @@ def tree(self):
252
253
return render_tree (self )
253
254
254
255
256
+ @dataclasses .dataclass
257
+ class KNNExpression :
258
+ k : int
259
+ vector_field : ModelField
260
+ reference_vector : bytes
261
+
262
+ def __str__ (self ):
263
+ return f"KNN $K @{ self .vector_field .name } $knn_ref_vector"
264
+
265
+ @property
266
+ def query_params (self ) -> Dict [str , Union [str , bytes ]]:
267
+ return {"K" : str (self .k ), "knn_ref_vector" : self .reference_vector }
268
+
269
+ @property
270
+ def score_field (self ) -> str :
271
+ return f"__{ self .vector_field .name } _score"
272
+
273
+
255
274
ExpressionOrNegated = Union [Expression , NegatedExpression ]
256
275
257
276
@@ -349,8 +368,9 @@ def __init__(
349
368
self ,
350
369
expressions : Sequence [ExpressionOrNegated ],
351
370
model : Type ["RedisModel" ],
371
+ knn : Optional [KNNExpression ] = None ,
352
372
offset : int = 0 ,
353
- limit : int = DEFAULT_PAGE_SIZE ,
373
+ limit : Optional [ int ] = None ,
354
374
page_size : int = DEFAULT_PAGE_SIZE ,
355
375
sort_fields : Optional [List [str ]] = None ,
356
376
nocontent : bool = False ,
@@ -364,13 +384,16 @@ def __init__(
364
384
365
385
self .expressions = expressions
366
386
self .model = model
387
+ self .knn = knn
367
388
self .offset = offset
368
- self .limit = limit
389
+ self .limit = limit or ( self . knn . k if self . knn else DEFAULT_PAGE_SIZE )
369
390
self .page_size = page_size
370
391
self .nocontent = nocontent
371
392
372
393
if sort_fields :
373
394
self .sort_fields = self .validate_sort_fields (sort_fields )
395
+ elif self .knn :
396
+ self .sort_fields = [self .knn .score_field ]
374
397
else :
375
398
self .sort_fields = []
376
399
@@ -425,11 +448,26 @@ def query(self):
425
448
if self ._query :
426
449
return self ._query
427
450
self ._query = self .resolve_redisearch_query (self .expression )
451
+ if self .knn :
452
+ self ._query = (
453
+ self ._query
454
+ if self ._query .startswith ("(" ) or self ._query == "*"
455
+ else f"({ self ._query } )"
456
+ ) + f"=>[{ self .knn } ]"
428
457
return self ._query
429
458
459
+ @property
460
+ def query_params (self ):
461
+ params : List [Union [str , bytes ]] = []
462
+ if self .knn :
463
+ params += [attr for kv in self .knn .query_params .items () for attr in kv ]
464
+ return params
465
+
430
466
def validate_sort_fields (self , sort_fields : List [str ]):
431
467
for sort_field in sort_fields :
432
468
field_name = sort_field .lstrip ("-" )
469
+ if self .knn and field_name == self .knn .score_field :
470
+ continue
433
471
if field_name not in self .model .__fields__ :
434
472
raise QueryNotSupportedError (
435
473
f"You tried sort by { field_name } , but that field "
@@ -728,10 +766,27 @@ def resolve_redisearch_query(cls, expression: ExpressionOrNegated) -> str:
728
766
return result
729
767
730
768
async def execute (self , exhaust_results = True , return_raw_result = False ):
731
- args = ["ft.search" , self .model .Meta .index_name , self .query , * self .pagination ]
769
+ args : List [Union [str , bytes ]] = [
770
+ "FT.SEARCH" ,
771
+ self .model .Meta .index_name ,
772
+ self .query ,
773
+ * self .pagination ,
774
+ ]
732
775
if self .sort_fields :
733
776
args += self .resolve_redisearch_sort_fields ()
734
777
778
+ if self .query_params :
779
+ args += ["PARAMS" , str (len (self .query_params ))] + self .query_params
780
+
781
+ if self .knn :
782
+ # Ensure DIALECT is at least 2
783
+ if "DIALECT" not in args :
784
+ args += ["DIALECT" , "2" ]
785
+ else :
786
+ i_dialect = args .index ("DIALECT" ) + 1
787
+ if int (args [i_dialect ]) < 2 :
788
+ args [i_dialect ] = "2"
789
+
735
790
if self .nocontent :
736
791
args .append ("NOCONTENT" )
737
792
@@ -917,11 +972,13 @@ def __init__(self, default: Any = Undefined, **kwargs: Any) -> None:
917
972
sortable = kwargs .pop ("sortable" , Undefined )
918
973
index = kwargs .pop ("index" , Undefined )
919
974
full_text_search = kwargs .pop ("full_text_search" , Undefined )
975
+ vector_options = kwargs .pop ("vector_options" , None )
920
976
super ().__init__ (default = default , ** kwargs )
921
977
self .primary_key = primary_key
922
978
self .sortable = sortable
923
979
self .index = index
924
980
self .full_text_search = full_text_search
981
+ self .vector_options = vector_options
925
982
926
983
927
984
class RelationshipInfo (Representation ):
@@ -935,6 +992,94 @@ def __init__(
935
992
self .link_model = link_model
936
993
937
994
995
+ @dataclasses .dataclass
996
+ class VectorFieldOptions :
997
+ class ALGORITHM (Enum ):
998
+ FLAT = "FLAT"
999
+ HNSW = "HNSW"
1000
+
1001
+ class TYPE (Enum ):
1002
+ FLOAT32 = "FLOAT32"
1003
+ FLOAT64 = "FLOAT64"
1004
+
1005
+ class DISTANCE_METRIC (Enum ):
1006
+ L2 = "L2"
1007
+ IP = "IP"
1008
+ COSINE = "COSINE"
1009
+
1010
+ algorithm : ALGORITHM
1011
+ type : TYPE
1012
+ dimension : int
1013
+ distance_metric : DISTANCE_METRIC
1014
+
1015
+ # Common optional parameters
1016
+ initial_cap : Optional [int ] = None
1017
+
1018
+ # Optional parameters for FLAT
1019
+ block_size : Optional [int ] = None
1020
+
1021
+ # Optional parameters for HNSW
1022
+ m : Optional [int ] = None
1023
+ ef_construction : Optional [int ] = None
1024
+ ef_runtime : Optional [int ] = None
1025
+ epsilon : Optional [float ] = None
1026
+
1027
+ @staticmethod
1028
+ def flat (
1029
+ type : TYPE ,
1030
+ dimension : int ,
1031
+ distance_metric : DISTANCE_METRIC ,
1032
+ initial_cap : Optional [int ] = None ,
1033
+ block_size : Optional [int ] = None ,
1034
+ ):
1035
+ return VectorFieldOptions (
1036
+ algorithm = VectorFieldOptions .ALGORITHM .FLAT ,
1037
+ type = type ,
1038
+ dimension = dimension ,
1039
+ distance_metric = distance_metric ,
1040
+ initial_cap = initial_cap ,
1041
+ block_size = block_size ,
1042
+ )
1043
+
1044
+ @staticmethod
1045
+ def hnsw (
1046
+ type : TYPE ,
1047
+ dimension : int ,
1048
+ distance_metric : DISTANCE_METRIC ,
1049
+ initial_cap : Optional [int ] = None ,
1050
+ m : Optional [int ] = None ,
1051
+ ef_construction : Optional [int ] = None ,
1052
+ ef_runtime : Optional [int ] = None ,
1053
+ epsilon : Optional [float ] = None ,
1054
+ ):
1055
+ return VectorFieldOptions (
1056
+ algorithm = VectorFieldOptions .ALGORITHM .HNSW ,
1057
+ type = type ,
1058
+ dimension = dimension ,
1059
+ distance_metric = distance_metric ,
1060
+ initial_cap = initial_cap ,
1061
+ m = m ,
1062
+ ef_construction = ef_construction ,
1063
+ ef_runtime = ef_runtime ,
1064
+ epsilon = epsilon ,
1065
+ )
1066
+
1067
+ @property
1068
+ def schema (self ):
1069
+ attr = []
1070
+ for k , v in vars (self ).items ():
1071
+ if k == "algorithm" or v is None :
1072
+ continue
1073
+ attr .extend (
1074
+ [
1075
+ k .upper () if k != "dimension" else "DIM" ,
1076
+ str (v ) if not isinstance (v , Enum ) else v .name ,
1077
+ ]
1078
+ )
1079
+
1080
+ return " " .join ([f"VECTOR { self .algorithm .name } { len (attr )} " ] + attr )
1081
+
1082
+
938
1083
def Field (
939
1084
default : Any = Undefined ,
940
1085
* ,
@@ -964,6 +1109,7 @@ def Field(
964
1109
sortable : Union [bool , UndefinedType ] = Undefined ,
965
1110
index : Union [bool , UndefinedType ] = Undefined ,
966
1111
full_text_search : Union [bool , UndefinedType ] = Undefined ,
1112
+ vector_options : Optional [VectorFieldOptions ] = None ,
967
1113
schema_extra : Optional [Dict [str , Any ]] = None ,
968
1114
) -> Any :
969
1115
current_schema_extra = schema_extra or {}
@@ -991,6 +1137,7 @@ def Field(
991
1137
sortable = sortable ,
992
1138
index = index ,
993
1139
full_text_search = full_text_search ,
1140
+ vector_options = vector_options ,
994
1141
** current_schema_extra ,
995
1142
)
996
1143
field_info ._validate ()
@@ -1083,6 +1230,10 @@ def __new__(cls, name, bases, attrs, **kwargs): # noqa C901
1083
1230
new_class ._meta .primary_key = PrimaryKey (
1084
1231
name = field_name , field = field
1085
1232
)
1233
+ if field .field_info .vector_options :
1234
+ score_attr = f"_{ field_name } _score"
1235
+ setattr (new_class , score_attr , None )
1236
+ new_class .__annotations__ [score_attr ] = Union [float , None ]
1086
1237
1087
1238
if not getattr (new_class ._meta , "global_key_prefix" , None ):
1088
1239
new_class ._meta .global_key_prefix = getattr (
@@ -1160,16 +1311,16 @@ async def delete(
1160
1311
return await cls ._delete (db , cls .make_primary_key (pk ))
1161
1312
1162
1313
@classmethod
1163
- async def get (cls , pk : Any ) -> "RedisModel " :
1314
+ async def get (cls : Type [ "Model" ] , pk : Any ) -> "Model " :
1164
1315
raise NotImplementedError
1165
1316
1166
1317
async def update (self , ** field_values ):
1167
1318
"""Update this model instance with the specified key-value pairs."""
1168
1319
raise NotImplementedError
1169
1320
1170
1321
async def save (
1171
- self , pipeline : Optional [redis .client .Pipeline ] = None
1172
- ) -> "RedisModel " :
1322
+ self : "Model" , pipeline : Optional [redis .client .Pipeline ] = None
1323
+ ) -> "Model " :
1173
1324
raise NotImplementedError
1174
1325
1175
1326
async def expire (
@@ -1216,8 +1367,12 @@ def db(cls):
1216
1367
return cls ._meta .database
1217
1368
1218
1369
@classmethod
1219
- def find (cls , * expressions : Union [Any , Expression ]) -> FindQuery :
1220
- return FindQuery (expressions = expressions , model = cls )
1370
+ def find (
1371
+ cls ,
1372
+ * expressions : Union [Any , Expression ],
1373
+ knn : Optional [KNNExpression ] = None ,
1374
+ ) -> FindQuery :
1375
+ return FindQuery (expressions = expressions , knn = knn , model = cls )
1221
1376
1222
1377
@classmethod
1223
1378
def from_redis (cls , res : Any ):
@@ -1237,7 +1392,7 @@ def to_string(s):
1237
1392
for i in range (1 , len (res ), step ):
1238
1393
if res [i + offset ] is None :
1239
1394
continue
1240
- fields = dict (
1395
+ fields : Dict [ str , str ] = dict (
1241
1396
zip (
1242
1397
map (to_string , res [i + offset ][::2 ]),
1243
1398
map (to_string , res [i + offset ][1 ::2 ]),
@@ -1247,6 +1402,9 @@ def to_string(s):
1247
1402
if fields .get ("$" ):
1248
1403
json_fields = json .loads (fields .pop ("$" ))
1249
1404
doc = cls (** json_fields )
1405
+ for k , v in fields .items ():
1406
+ if k .startswith ("__" ) and k .endswith ("_score" ):
1407
+ setattr (doc , k [1 :], float (v ))
1250
1408
else :
1251
1409
doc = cls (** fields )
1252
1410
@@ -1266,11 +1424,11 @@ def get_annotations(cls):
1266
1424
1267
1425
@classmethod
1268
1426
async def add (
1269
- cls ,
1270
- models : Sequence ["RedisModel " ],
1427
+ cls : Type [ "Model" ] ,
1428
+ models : Sequence ["Model " ],
1271
1429
pipeline : Optional [redis .client .Pipeline ] = None ,
1272
1430
pipeline_verifier : Callable [..., Any ] = verify_pipeline_response ,
1273
- ) -> Sequence ["RedisModel " ]:
1431
+ ) -> Sequence ["Model " ]:
1274
1432
db = cls ._get_db (pipeline , bulk = True )
1275
1433
1276
1434
for model in models :
@@ -1345,8 +1503,8 @@ def __init_subclass__(cls, **kwargs):
1345
1503
)
1346
1504
1347
1505
async def save (
1348
- self , pipeline : Optional [redis .client .Pipeline ] = None
1349
- ) -> "HashModel " :
1506
+ self : "Model" , pipeline : Optional [redis .client .Pipeline ] = None
1507
+ ) -> "Model " :
1350
1508
self .check ()
1351
1509
db = self ._get_db (pipeline )
1352
1510
@@ -1368,7 +1526,7 @@ async def all_pks(cls): # type: ignore
1368
1526
)
1369
1527
1370
1528
@classmethod
1371
- async def get (cls , pk : Any ) -> "HashModel " :
1529
+ async def get (cls : Type [ "Model" ] , pk : Any ) -> "Model " :
1372
1530
document = await cls .db ().hgetall (cls .make_primary_key (pk ))
1373
1531
if not document :
1374
1532
raise NotFoundError
@@ -1474,7 +1632,13 @@ def schema_for_type(cls, name, typ: Any, field_info: PydanticFieldInfo):
1474
1632
embedded_cls = embedded_cls [0 ]
1475
1633
schema = cls .schema_for_type (name , embedded_cls , field_info )
1476
1634
elif any (issubclass (typ , t ) for t in NUMERIC_TYPES ):
1477
- schema = f"{ name } NUMERIC"
1635
+ vector_options : Optional [VectorFieldOptions ] = getattr (
1636
+ field_info , "vector_options" , None
1637
+ )
1638
+ if vector_options :
1639
+ schema = f"{ name } { vector_options .schema } "
1640
+ else :
1641
+ schema = f"{ name } NUMERIC"
1478
1642
elif issubclass (typ , str ):
1479
1643
if getattr (field_info , "full_text_search" , False ) is True :
1480
1644
schema = (
@@ -1513,8 +1677,8 @@ def __init__(self, *args, **kwargs):
1513
1677
super ().__init__ (* args , ** kwargs )
1514
1678
1515
1679
async def save (
1516
- self , pipeline : Optional [redis .client .Pipeline ] = None
1517
- ) -> "JsonModel " :
1680
+ self : "Model" , pipeline : Optional [redis .client .Pipeline ] = None
1681
+ ) -> "Model " :
1518
1682
self .check ()
1519
1683
db = self ._get_db (pipeline )
1520
1684
@@ -1559,7 +1723,7 @@ async def update(self, **field_values):
1559
1723
await self .save ()
1560
1724
1561
1725
@classmethod
1562
- async def get (cls , pk : Any ) -> "JsonModel " :
1726
+ async def get (cls : Type [ "Model" ] , pk : Any ) -> "Model " :
1563
1727
document = json .dumps (await cls .db ().json ().get (cls .make_key (pk )))
1564
1728
if document == "null" :
1565
1729
raise NotFoundError
@@ -1623,10 +1787,22 @@ def schema_for_type(
1623
1787
# Not a class, probably a type annotation
1624
1788
field_is_model = False
1625
1789
1790
+ vector_options : Optional [VectorFieldOptions ] = getattr (
1791
+ field_info , "vector_options" , None
1792
+ )
1793
+ try :
1794
+ is_vector = vector_options and any (
1795
+ issubclass (get_args (typ )[0 ], t ) for t in NUMERIC_TYPES
1796
+ )
1797
+ except IndexError :
1798
+ raise RedisModelError (
1799
+ f"Vector field '{ name } ' must be annotated as a container type"
1800
+ )
1801
+
1626
1802
# When we encounter a list or model field, we need to descend
1627
1803
# into the values of the list or the fields of the model to
1628
1804
# find any values marked as indexed.
1629
- if is_container_type :
1805
+ if is_container_type and not is_vector :
1630
1806
field_type = get_origin (typ )
1631
1807
embedded_cls = get_args (typ )
1632
1808
if not embedded_cls :
@@ -1689,7 +1865,9 @@ def schema_for_type(
1689
1865
)
1690
1866
1691
1867
# TODO: GEO field
1692
- if parent_is_container_type or parent_is_model_in_container :
1868
+ if is_vector and vector_options :
1869
+ schema = f"{ path } AS { index_field_name } { vector_options .schema } "
1870
+ elif parent_is_container_type or parent_is_model_in_container :
1693
1871
if typ is not str :
1694
1872
raise RedisModelError (
1695
1873
"In this Preview release, list and tuple fields can only "
0 commit comments