Skip to content

Commit c74990a

Browse files
authored
Expose boost query (#250)
1 parent ed7374c commit c74990a

File tree

3 files changed

+119
-10
lines changed

3 files changed

+119
-10
lines changed

src/query.rs

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -178,4 +178,13 @@ impl Query {
178178
inner: Box::new(dismax_query),
179179
})
180180
}
181+
182+
/// Build a boosted version of `query`.
///
/// Wraps the inner tantivy query of `query` in a `tv::query::BoostQuery`
/// constructed with the factor `boost`, and returns the result boxed inside
/// a new `Query`.
///
/// Exposed to Python as a static method; the pyo3 signature declares both
/// `query` and `boost` as required arguments with no defaults.
///
/// The body itself has no failing path and always returns `Ok`.
#[staticmethod]
183+
#[pyo3(signature = (query, boost))]
184+
pub(crate) fn boost_query(query: Query, boost: f32) -> PyResult<Query> {
185+
// `query` is taken by value, so its inner query is moved into the wrapper.
let inner = tv::query::BoostQuery::new(query.inner, boost);
186+
Ok(Query {
187+
inner: Box::new(inner),
188+
})
189+
}
181190
}

tantivy/tantivy.pyi

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -212,6 +212,10 @@ class Query:
212212
# Type stub: static constructor taking a sequence of subqueries and an
# optional `tie_breaker` float (defaults to None); returns a Query.
@staticmethod
213213
def disjunction_max_query(subqueries: Sequence[Query], tie_breaker: Optional[float] = None) -> Query:
214214
pass
215+
216+
# Type stub: static constructor taking an existing Query and a float
# `boost`; both arguments are required. Returns a new Query.
@staticmethod
217+
def boost_query(query: Query, boost: float) -> Query:
218+
pass
215219

216220

217221
class Order(Enum):

tests/tantivy_test.py

Lines changed: 106 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -827,52 +827,52 @@ def test_boolean_query(self, ram_index):
827827
(Occur.Must, query1),
828828
(Occur.Must, query2)
829829
])
830-
830+
831831
# no document should match both queries
832832
result = index.searcher().search(query, 10)
833833
assert len(result.hits) == 0
834-
834+
835835
query = Query.boolean_query([
836836
(Occur.Should, query1),
837837
(Occur.Should, query2)
838838
])
839-
839+
840840
# two documents should match, one for each query
841841
result = index.searcher().search(query, 10)
842842
assert len(result.hits) == 2
843-
843+
844844
titles = set()
845845
for _, doc_address in result.hits:
846846
titles.update(index.searcher().doc(doc_address)["title"])
847847
assert (
848848
"The Old Man and the Sea" in titles and
849849
"Of Mice and Men" in titles
850850
)
851-
851+
852852
query = Query.boolean_query([
853853
(Occur.MustNot, query1),
854854
(Occur.Must, query1)
855855
])
856-
856+
857857
# must not should take precedence over must
858858
result = index.searcher().search(query, 10)
859859
assert len(result.hits) == 0
860-
860+
861861
query = Query.boolean_query((
862862
(Occur.Should, query1),
863863
(Occur.Should, query2)
864864
))
865-
865+
866866
# the Vec signature should fit the tuple signature
867867
result = index.searcher().search(query, 10)
868868
assert len(result.hits) == 2
869-
869+
870870
# test invalid queries
871871
with pytest.raises(ValueError, match = "expected tuple of length 2, but got tuple of length 3"):
872872
Query.boolean_query([
873873
(Occur.Must, Occur.Must, query1),
874874
])
875-
875+
876876
# test swapping the order of the tuple
877877
with pytest.raises(TypeError, match = r"'Query' object cannot be converted to 'Occur'"):
878878
Query.boolean_query([
@@ -899,3 +899,99 @@ def test_disjunction_max_query(self, ram_index):
899899

900900
with pytest.raises(TypeError, match = r"'str' object cannot be converted to 'Query'"):
901901
query = Query.disjunction_max_query([query1, "not a query"], tie_breaker=0.5)
902+
903+
904+
# Exercises Query.boost_query against the `ram_index` fixture:
#   - repr() of a boosted term query, a boosted boolean query, and of
#     decimal, zero, nested and negative boost values;
#   - score expectations for zero, nested and negative boosts;
#   - TypeError for a non-Query argument, a non-numeric boost, and a
#     missing boost argument.
def test_boost_query(self, ram_index):
905+
index = ram_index
906+
query1 = Query.term_query(index.schema, "title", "sea")
907+
boosted_query = Query.boost_query(query1, 2.0)
908+
909+
# Normal boost query
910+
assert (
911+
repr(boosted_query)
912+
== """Query(Boost(query=TermQuery(Term(field=0, type=Str, "sea")), boost=2))"""
913+
)
914+
915+
query2 = Query.fuzzy_term_query(index.schema, "title", "ice")
916+
combined_query = Query.boolean_query([
917+
(Occur.Should, boosted_query),
918+
(Occur.Should, query2)
919+
])
920+
boosted_query = Query.boost_query(combined_query, 2.0)
921+
922+
# Boosted boolean query (a boost wrapping a boolean query that itself contains a boost)
923+
assert (
924+
repr(boosted_query)
925+
== """Query(Boost(query=BooleanQuery { subqueries: [(Should, Boost(query=TermQuery(Term(field=0, type=Str, "sea")), boost=2)), (Should, FuzzyTermQuery { term: Term(field=0, type=Str, "ice"), distance: 1, transposition_cost_one: true, prefix: false })] }, boost=2))"""
926+
)
927+
928+
boosted_query = Query.boost_query(query1, 0.1)
929+
930+
# Check for decimal boost values
931+
assert(
932+
repr(boosted_query)
933+
== """Query(Boost(query=TermQuery(Term(field=0, type=Str, "sea")), boost=0.1))"""
934+
)
935+
936+
boosted_query = Query.boost_query(query1, 0.0)
937+
938+
# Check for zero boost values
939+
assert(
940+
repr(boosted_query)
941+
== """Query(Boost(query=TermQuery(Term(field=0, type=Str, "sea")), boost=0))"""
942+
)
943+
result = index.searcher().search(boosted_query, 10)
944+
# NOTE(review): this loop asserts nothing if the search returns no hits
for _score, _ in result.hits:
945+
# the score should be 0.0
946+
assert _score == pytest.approx(0.0)
947+
948+
boosted_query = Query.boost_query(
949+
Query.boost_query(
950+
query1, 0.1
951+
), 0.1
952+
)
953+
954+
# Check for nested boost queries
955+
assert(
956+
repr(boosted_query)
957+
== """Query(Boost(query=Boost(query=TermQuery(Term(field=0, type=Str, "sea")), boost=0.1), boost=0.1))"""
958+
)
959+
result = index.searcher().search(boosted_query, 10)
960+
# NOTE(review): this loop asserts nothing if the search returns no hits
for _score, _ in result.hits:
961+
# the score should be very small; the exact BM25 score is not known,
962+
# so we can only compare with a relative tolerance (rel = 1 accepts 0.0 to 0.02)
963+
assert _score == pytest.approx(0.01, rel = 1)
964+
965+
966+
boosted_query = Query.boost_query(
967+
query1, -0.1
968+
)
969+
970+
# Check for negative boost values
971+
assert(
972+
repr(boosted_query)
973+
== """Query(Boost(query=TermQuery(Term(field=0, type=Str, "sea")), boost=-0.1))"""
974+
)
975+
976+
result = index.searcher().search(boosted_query, 10)
977+
# Even with a negative boost, the query should still match the document
978+
assert len(result.hits) == 1
979+
titles = set()
980+
for _score, doc_address in result.hits:
981+
982+
# the score should be negative
983+
assert _score < 0
984+
titles.update(index.searcher().doc(doc_address)["title"])
985+
assert titles == {"The Old Man and the Sea"}
986+
987+
# wrong query type: a non-Query first argument must raise TypeError
988+
with pytest.raises(TypeError, match = r"'int' object cannot be converted to 'Query'"):
989+
Query.boost_query(1, 0.1)
990+
991+
# wrong boost type: a string is not accepted as the boost
992+
with pytest.raises(TypeError, match = r"argument 'boost': must be real number, not str"):
993+
Query.boost_query(query1, "0.1")
994+
995+
# missing boost argument: boost has no default value
996+
with pytest.raises(TypeError, match = r"Query.boost_query\(\) missing 1 required positional argument: 'boost'"):
997+
Query.boost_query(query1)

0 commit comments

Comments
 (0)