@@ -827,52 +827,52 @@ def test_boolean_query(self, ram_index):
827
827
(Occur .Must , query1 ),
828
828
(Occur .Must , query2 )
829
829
])
830
-
830
+
831
831
# no document should match both queries
832
832
result = index .searcher ().search (query , 10 )
833
833
assert len (result .hits ) == 0
834
-
834
+
835
835
query = Query .boolean_query ([
836
836
(Occur .Should , query1 ),
837
837
(Occur .Should , query2 )
838
838
])
839
-
839
+
840
840
# two documents should match, one for each query
841
841
result = index .searcher ().search (query , 10 )
842
842
assert len (result .hits ) == 2
843
-
843
+
844
844
titles = set ()
845
845
for _ , doc_address in result .hits :
846
846
titles .update (index .searcher ().doc (doc_address )["title" ])
847
847
assert (
848
848
"The Old Man and the Sea" in titles and
849
849
"Of Mice and Men" in titles
850
850
)
851
-
851
+
852
852
query = Query .boolean_query ([
853
853
(Occur .MustNot , query1 ),
854
854
(Occur .Must , query1 )
855
855
])
856
-
856
+
857
857
# must not should take precedence over must
858
858
result = index .searcher ().search (query , 10 )
859
859
assert len (result .hits ) == 0
860
-
860
+
861
861
query = Query .boolean_query ((
862
862
(Occur .Should , query1 ),
863
863
(Occur .Should , query2 )
864
864
))
865
-
865
+
866
866
# the Vec signature should fit the tuple signature
867
867
result = index .searcher ().search (query , 10 )
868
868
assert len (result .hits ) == 2
869
-
869
+
870
870
# test invalid queries
871
871
with pytest .raises (ValueError , match = "expected tuple of length 2, but got tuple of length 3" ):
872
872
Query .boolean_query ([
873
873
(Occur .Must , Occur .Must , query1 ),
874
874
])
875
-
875
+
876
876
# test swapping the order of the tuple
877
877
with pytest .raises (TypeError , match = r"'Query' object cannot be converted to 'Occur'" ):
878
878
Query .boolean_query ([
@@ -899,3 +899,99 @@ def test_disjunction_max_query(self, ram_index):
899
899
900
900
with pytest .raises (TypeError , match = r"'str' object cannot be converted to 'Query'" ):
901
901
query = Query .disjunction_max_query ([query1 , "not a query" ], tie_breaker = 0.5 )
902
+
903
+
904
def test_boost_query(self, ram_index):
    """Check ``Query.boost_query``: repr output, scoring, and argument errors.

    Covers a plain boost, a boosted boolean query, fractional, zero,
    nested and negative boost factors, and the ``TypeError`` raised for a
    wrong query type, a wrong boost type and a missing boost argument.

    Score checks are made relative to the unboosted score of the same
    term query, so they do not depend on the absolute BM25 value.
    """
    index = ram_index
    # The index is not modified in this test, so one searcher is reused.
    searcher = index.searcher()
    query1 = Query.term_query(index.schema, "title", "sea")

    # Unboosted reference: exactly one document matches the term "sea".
    baseline_hits = searcher.search(query1, 10).hits
    assert len(baseline_hits) == 1
    baseline_score, _ = baseline_hits[0]

    # Normal boost query
    boosted_query = Query.boost_query(query1, 2.0)
    assert (
        repr(boosted_query)
        == """Query(Boost(query=TermQuery(Term(field=0, type=Str, "sea")), boost=2))"""
    )

    # Boosted boolean query
    query2 = Query.fuzzy_term_query(index.schema, "title", "ice")
    combined_query = Query.boolean_query([
        (Occur.Should, boosted_query),
        (Occur.Should, query2),
    ])
    boosted_query = Query.boost_query(combined_query, 2.0)
    assert (
        repr(boosted_query)
        == """Query(Boost(query=BooleanQuery { subqueries: [(Should, Boost(query=TermQuery(Term(field=0, type=Str, "sea")), boost=2)), (Should, FuzzyTermQuery { term: Term(field=0, type=Str, "ice"), distance: 1, transposition_cost_one: true, prefix: false })] }, boost=2))"""
    )

    # Check for decimal boost values
    boosted_query = Query.boost_query(query1, 0.1)
    assert (
        repr(boosted_query)
        == """Query(Boost(query=TermQuery(Term(field=0, type=Str, "sea")), boost=0.1))"""
    )

    # Check for zero boost values
    boosted_query = Query.boost_query(query1, 0.0)
    assert (
        repr(boosted_query)
        == """Query(Boost(query=TermQuery(Term(field=0, type=Str, "sea")), boost=0))"""
    )
    result = searcher.search(boosted_query, 10)
    # Guard against a vacuous pass: the loop below asserts nothing when
    # there are no hits.
    assert len(result.hits) == 1
    for score, _ in result.hits:
        # the score should be 0.0
        assert score == pytest.approx(0.0)

    # Check for nested boost queries
    boosted_query = Query.boost_query(Query.boost_query(query1, 0.1), 0.1)
    assert (
        repr(boosted_query)
        == """Query(Boost(query=Boost(query=TermQuery(Term(field=0, type=Str, "sea")), boost=0.1), boost=0.1))"""
    )
    result = searcher.search(boosted_query, 10)
    assert len(result.hits) == 1
    for score, _ in result.hits:
        # The absolute BM25 score is unknown, but nested boosts multiply,
        # so the score must be 0.1 * 0.1 of the unboosted score. The
        # tolerance is loose enough to absorb single-precision rounding.
        assert score == pytest.approx(baseline_score * 0.01, rel=1e-4)

    # Check for negative boost values
    boosted_query = Query.boost_query(query1, -0.1)
    assert (
        repr(boosted_query)
        == """Query(Boost(query=TermQuery(Term(field=0, type=Str, "sea")), boost=-0.1))"""
    )

    result = searcher.search(boosted_query, 10)
    # Even with a negative boost, the query should still match the document
    assert len(result.hits) == 1
    titles = set()
    for score, doc_address in result.hits:
        # the score should be negative
        assert score < 0
        titles.update(searcher.doc(doc_address)["title"])
    assert titles == {"The Old Man and the Sea"}

    # wrong query type
    with pytest.raises(TypeError, match=r"'int' object cannot be converted to 'Query'"):
        Query.boost_query(1, 0.1)

    # wrong boost type
    with pytest.raises(TypeError, match=r"argument 'boost': must be real number, not str"):
        Query.boost_query(query1, "0.1")

    # no boost type error
    with pytest.raises(TypeError, match=r"Query.boost_query\(\) missing 1 required positional argument: 'boost'"):
        Query.boost_query(query1)
0 commit comments