Skip to content

Commit d35fe09

Browse files
committed
Expose Tantivy's TermSetQuery
1 parent 7e57a00 commit d35fe09

File tree

3 files changed

+53
-0
lines changed

3 files changed

+53
-0
lines changed

src/query.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,26 @@ impl Query {
8989
})
9090
}
9191

92+
/// Construct a Tantivy's TermSetQuery
93+
#[staticmethod]
94+
#[pyo3(signature = (schema, field_name, field_values))]
95+
pub(crate) fn term_set_query(
96+
schema: &Schema,
97+
field_name: &str,
98+
field_values: Vec<&PyAny>,
99+
) -> PyResult<Query> {
100+
let terms = field_values
101+
.into_iter()
102+
.map(|field_value| {
103+
make_term(&schema.inner, field_name, &field_value)
104+
})
105+
.collect::<Result<Vec<_>, _>>()?;
106+
let inner = tv::query::TermSetQuery::new(terms);
107+
Ok(Query {
108+
inner: Box::new(inner),
109+
})
110+
}
111+
92112
/// Construct a Tantivy's AllQuery
93113
#[staticmethod]
94114
pub(crate) fn all_query() -> PyResult<Query> {

tantivy/tantivy.pyi

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,10 @@ class Query:
197197
) -> Query:
198198
pass
199199

200+
@staticmethod
201+
def term_set_query(schema: Schema, field_name: str, field_values: Sequence[Any]) -> Query:
202+
pass
203+
200204
@staticmethod
201205
def all_query() -> Query:
202206
pass

tests/tantivy_test.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -765,6 +765,35 @@ def test_term_query(self, ram_index):
765765
searched_doc = index.searcher().doc(doc_address)
766766
assert searched_doc["title"] == ["The Old Man and the Sea"]
767767

768+
def test_term_set_query(self, ram_index):
769+
index = ram_index
770+
771+
# Should match 1 document that contains both terms
772+
terms = ["old", "man"]
773+
query = Query.term_set_query(index.schema, "title", terms)
774+
result = index.searcher().search(query, 10)
775+
assert len(result.hits) == 1
776+
_, doc_address = result.hits[0]
777+
searched_doc = index.searcher().doc(doc_address)
778+
assert searched_doc["title"] == ["The Old Man and the Sea"]
779+
780+
# Should not match any document since the term does not exist in the index
781+
terms = ["a long term that does not exist in the index"]
782+
query = Query.term_set_query(index.schema, "title", terms)
783+
result = index.searcher().search(query, 10)
784+
assert len(result.hits) == 0
785+
786+
# Should not match any document when the terms list is empty
787+
terms = []
788+
query = Query.term_set_query(index.schema, "title", terms)
789+
result = index.searcher().search(query, 10)
790+
assert len(result.hits) == 0
791+
792+
# Should fail to create the query due to the invalid list object in the terms list
793+
with pytest.raises(ValueError, match = r"Can't create a term for Field `title` with value `\[\]`"):
794+
terms = ["old", [], "man"]
795+
query = Query.term_set_query(index.schema, "title", terms)
796+
768797
def test_all_query(self, ram_index):
769798
index = ram_index
770799
query = Query.all_query()

0 commit comments

Comments
 (0)