Skip to content

Commit 9fafdf2

Browse files
authored
Expose Tantivy's MoreLikeThisQuery (#257)
1 parent 03b1c89 commit 9fafdf2

File tree

4 files changed

+99
-4
lines changed

4 files changed

+99
-4
lines changed

src/query.rs

Lines changed: 47 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use crate::{get_field, make_term, to_pyerr, Schema};
1+
use crate::{get_field, make_term, to_pyerr, DocAddress, Schema};
22
use pyo3::{
33
exceptions,
44
prelude::*,
@@ -100,7 +100,7 @@ impl Query {
100100
let terms = field_values
101101
.into_iter()
102102
.map(|field_value| {
103-
make_term(&schema.inner, field_name, &field_value)
103+
make_term(&schema.inner, field_name, field_value)
104104
})
105105
.collect::<Result<Vec<_>, _>>()?;
106106
let inner = tv::query::TermSetQuery::new(terms);
@@ -138,7 +138,7 @@ impl Query {
138138
transposition_cost_one: bool,
139139
prefix: bool,
140140
) -> PyResult<Query> {
141-
let term = make_term(&schema.inner, field_name, &text)?;
141+
let term = make_term(&schema.inner, field_name, text)?;
142142
let inner = if prefix {
143143
tv::query::FuzzyTermQuery::new_prefix(
144144
term,
@@ -272,6 +272,50 @@ impl Query {
272272
}
273273
}
274274

275+
/// Construct a Tantivy MoreLikeThisQuery for the document at `doc_address`.
///
/// Every `Option` argument is forwarded to the query builder only when it is
/// `Some`; `None` skips the corresponding `with_*` call. `stop_words` is
/// always passed to the builder, even when it is empty.
///
/// Defaults when called from Python (taken from the `pyo3` signature below):
/// `min_doc_frequency = 5`, `min_term_frequency = 2`, `max_query_terms = 25`,
/// `boost_factor = 1.0`, an empty `stop_words` list, and `None` for
/// `max_doc_frequency`, `min_word_length` and `max_word_length`.
///
/// Returns a `Query` wrapping the boxed MoreLikeThisQuery. Nothing in this
/// body produces an error, so the result is always `Ok`.
#[staticmethod]
276+
#[pyo3(signature = (doc_address, min_doc_frequency = Some(5), max_doc_frequency = None, min_term_frequency = Some(2), max_query_terms = Some(25), min_word_length = None, max_word_length = None, boost_factor = Some(1.0), stop_words = vec![]))]
277+
#[allow(clippy::too_many_arguments)]
278+
pub(crate) fn more_like_this_query(
279+
doc_address: &DocAddress,
280+
min_doc_frequency: Option<u64>,
281+
max_doc_frequency: Option<u64>,
282+
min_term_frequency: Option<usize>,
283+
max_query_terms: Option<usize>,
284+
min_word_length: Option<usize>,
285+
max_word_length: Option<usize>,
286+
boost_factor: Option<f32>,
287+
stop_words: Vec<String>,
288+
) -> PyResult<Query> {
289+
// Each `with_*` call consumes the builder and returns it, hence the
// repeated rebinding of `builder` below.
let mut builder = tv::query::MoreLikeThisQuery::builder();
290+
if let Some(value) = min_doc_frequency {
291+
builder = builder.with_min_doc_frequency(value);
292+
}
293+
if let Some(value) = max_doc_frequency {
294+
builder = builder.with_max_doc_frequency(value);
295+
}
296+
if let Some(value) = min_term_frequency {
297+
builder = builder.with_min_term_frequency(value);
298+
}
299+
if let Some(value) = max_query_terms {
300+
builder = builder.with_max_query_terms(value);
301+
}
302+
if let Some(value) = min_word_length {
303+
builder = builder.with_min_word_length(value);
304+
}
305+
if let Some(value) = max_word_length {
306+
builder = builder.with_max_word_length(value);
307+
}
308+
if let Some(value) = boost_factor {
309+
builder = builder.with_boost_factor(value);
310+
}
311+
// Unlike the optional settings above, the stop-word list is set unconditionally.
builder = builder.with_stop_words(stop_words);
312+
313+
// Finish the builder by naming the target document; the Python-facing
// `DocAddress` wrapper is converted into `tv::DocAddress` first.
let inner = builder.with_document(tv::DocAddress::from(doc_address));
314+
Ok(Query {
315+
inner: Box::new(inner),
316+
})
317+
}
318+
275319
/// Construct a Tantivy ConstScoreQuery
276320
#[staticmethod]
277321
#[pyo3(signature = (query, score))]

src/searcher.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@ impl Searcher {
153153
///
154154
/// Raises a ValueError if there was an error with the search.
155155
#[pyo3(signature = (query, limit = 10, count = true, order_by_field = None, offset = 0, order = Order::Desc))]
156+
#[allow(clippy::too_many_arguments)]
156157
fn search(
157158
&self,
158159
py: Python,

tantivy/tantivy.pyi

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,20 @@ class Query:
227227
def regex_query(schema: Schema, field_name: str, regex_pattern: str) -> Query:
228228
pass
229229

230+
# Type stub for Query.more_like_this_query; the `pass` body is a placeholder.
# The default values mirror the pyo3 signature declared for this method in
# src/query.rs. Each Optional parameter accepts None.
@staticmethod
231+
def more_like_this_query(
232+
doc_address: DocAddress,
233+
min_doc_frequency: Optional[int] = 5,
234+
max_doc_frequency: Optional[int] = None,
235+
min_term_frequency: Optional[int] = 2,
236+
max_query_terms: Optional[int] = 25,
237+
min_word_length: Optional[int] = None,
238+
max_word_length: Optional[int] = None,
239+
boost_factor: Optional[float] = 1.0,
240+
# NOTE(review): `[]` matches the Rust-side `vec![]` default; stub files
# conventionally write `...` for defaults, so confirm which is preferred here.
stop_words: list[str] = []
241+
) -> Query:
242+
pass
243+
230244
# Type stub: takes an existing Query and a float score and returns a Query.
# The `pass` body is a placeholder.
@staticmethod
231245
def const_score_query(query: Query, score: float) -> Query:
232246
pass

tests/tantivy_test.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1090,6 +1090,42 @@ def test_regex_query(self, ram_index):
10901090
):
10911091
Query.regex_query(index.schema, "body", "fish(")
10921092

1093+
# Exercises Query.more_like_this_query twice against the `ram_index` fixture:
# once with every parameter left at its default, once with every parameter
# overridden. Each time it checks both the query's repr and the number of hits.
def test_more_like_this_query(self, ram_index):
1094+
index = ram_index
1095+
1096+
# first, search the target doc
1097+
query = Query.term_query(index.schema, "title", "man")
1098+
result = index.searcher().search(query, 1)
1099+
# each hit is a pair; only its second element (the doc address) is needed
_, doc_address = result.hits[0]
1100+
searched_doc = index.searcher().doc(doc_address)
1101+
assert searched_doc["title"] == ["The Old Man and the Sea"]
1102+
1103+
# construct the default MLT Query
1104+
mlt_query = Query.more_like_this_query(doc_address)
1105+
# The repr is compared verbatim, so this pins the default of every parameter.
# NOTE(review): the `DocumentAdress` spelling is part of the expected string;
# presumably it comes from the underlying library - confirm before "fixing" it.
assert (
1106+
repr(mlt_query)
1107+
== "Query(MoreLikeThisQuery { mlt: MoreLikeThis { min_doc_frequency: Some(5), max_doc_frequency: None, min_term_frequency: Some(2), max_query_terms: Some(25), min_word_length: None, max_word_length: None, boost_factor: Some(1.0), stop_words: [] }, target: DocumentAdress(DocAddress { segment_ord: 0, doc_id: 0 }) })"
1108+
)
1109+
result = index.searcher().search(mlt_query, 10)
1110+
# with the default settings the fixture index is expected to return no hits
assert len(result.hits) == 0
1111+
1112+
# construct a fine-tuned MLT Query
1113+
mlt_query = Query.more_like_this_query(
1114+
doc_address,
1115+
min_doc_frequency=2,
1116+
max_doc_frequency=10,
1117+
min_term_frequency=1,
1118+
max_query_terms=10,
1119+
min_word_length=2,
1120+
max_word_length=20,
1121+
boost_factor=2.0,
1122+
stop_words=["fish"])
1123+
# every overridden value must show up in the repr
assert (
1124+
repr(mlt_query)
1125+
== "Query(MoreLikeThisQuery { mlt: MoreLikeThis { min_doc_frequency: Some(2), max_doc_frequency: Some(10), min_term_frequency: Some(1), max_query_terms: Some(10), min_word_length: Some(2), max_word_length: Some(20), boost_factor: Some(2.0), stop_words: [\"fish\"] }, target: DocumentAdress(DocAddress { segment_ord: 0, doc_id: 0 }) })"
1126+
)
1127+
result = index.searcher().search(mlt_query, 10)
1128+
# with the overridden settings at least one hit is expected
assert len(result.hits) > 0
10931129
def test_const_score_query(self, ram_index):
10941130
index = ram_index
10951131

@@ -1119,4 +1155,4 @@ def test_const_score_query(self, ram_index):
11191155
# wrong score type
11201156
with pytest.raises(TypeError, match = r"argument 'score': must be real number, not str"):
11211157
Query.const_score_query(query, "0.1")
1122-
1158+

0 commit comments

Comments
 (0)