Skip to content

Commit deb88cc

Browse files
aecio and cjrh authored
Expose Tantivy's DisjunctionMaxQuery (#244)
Co-authored-by: Caleb Hattingh <[email protected]>
1 parent 7651d2b commit deb88cc

File tree

3 files changed

+55
-2
lines changed

3 files changed

+55
-2
lines changed

src/query.rs

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
use crate::{make_term, Schema};
22
use pyo3::{
3-
exceptions, prelude::*, types::PyAny, types::PyString, types::PyTuple,
3+
exceptions,
4+
prelude::*,
5+
types::{PyAny, PyFloat, PyString, PyTuple},
46
};
57
use tantivy as tv;
68

@@ -151,4 +153,29 @@ impl Query {
151153
inner: Box::new(inner),
152154
})
153155
}
156+
157+
/// Construct a Tantivy DisjunctionMaxQuery
158+
#[staticmethod]
159+
pub(crate) fn disjunction_max_query(
160+
subqueries: Vec<Query>,
161+
tie_breaker: Option<&PyFloat>,
162+
) -> PyResult<Query> {
163+
let inner_queries: Vec<Box<dyn tv::query::Query>> = subqueries
164+
.iter()
165+
.map(|query| query.inner.box_clone())
166+
.collect();
167+
168+
let dismax_query = if let Some(tie_breaker) = tie_breaker {
169+
tv::query::DisjunctionMaxQuery::with_tie_breaker(
170+
inner_queries,
171+
tie_breaker.extract::<f32>()?,
172+
)
173+
} else {
174+
tv::query::DisjunctionMaxQuery::new(inner_queries)
175+
};
176+
177+
Ok(Query {
178+
inner: Box::new(dismax_query),
179+
})
180+
}
154181
}

tantivy/tantivy.pyi

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,11 @@ class Query:
209209
def boolean_query(subqueries: Sequence[tuple[Occur, Query]]) -> Query:
210210
pass
211211

212+
@staticmethod
213+
def disjunction_max_query(subqueries: Sequence[Query], tie_breaker: Optional[float] = None) -> Query:
214+
pass
215+
216+
212217
class Order(Enum):
213218
Asc = 1
214219
Desc = 2

tests/tantivy_test.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -877,4 +877,25 @@ def test_boolean_query(self, ram_index):
877877
with pytest.raises(TypeError, match = r"'Query' object cannot be converted to 'Occur'"):
878878
Query.boolean_query([
879879
(query1, Occur.Must),
880-
])
880+
])
881+
882+
def test_disjunction_max_query(self, ram_index):
883+
index = ram_index
884+
885+
# query1 should match the doc: "The Old Man and the Sea"
886+
query1 = Query.term_query(index.schema, "title", "sea")
887+
# query2 should match the doc: "Of Mice and Men"
888+
query2 = Query.term_query(index.schema, "title", "mice")
889+
# the disjunction max query should match both docs.
890+
query = Query.disjunction_max_query([query1, query2])
891+
892+
result = index.searcher().search(query, 10)
893+
assert len(result.hits) == 2
894+
895+
# the disjunction max query should also take a tie_breaker parameter
896+
query = Query.disjunction_max_query([query1, query2], tie_breaker=0.5)
897+
result = index.searcher().search(query, 10)
898+
assert len(result.hits) == 2
899+
900+
with pytest.raises(TypeError, match = r"'str' object cannot be converted to 'Query'"):
901+
query = Query.disjunction_max_query([query1, "not a query"], tie_breaker=0.5)

0 commit comments

Comments
 (0)