Skip to content

Commit 1d61b96

Browse files
authored
Expose the method of boolean query (#243)
1 parent 9fa82ef commit 1d61b96

File tree

4 files changed

+131
-6
lines changed

4 files changed

+131
-6
lines changed

src/lib.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ mod snippet;
1515
use document::Document;
1616
use facet::Facet;
1717
use index::Index;
18-
use query::Query;
18+
use query::{Occur, Query};
1919
use schema::Schema;
2020
use schemabuilder::SchemaBuilder;
2121
use searcher::{DocAddress, Order, SearchResult, Searcher};
@@ -87,6 +87,7 @@ fn tantivy(_py: Python, m: &PyModule) -> PyResult<()> {
8787
m.add_class::<Query>()?;
8888
m.add_class::<Snippet>()?;
8989
m.add_class::<SnippetGenerator>()?;
90+
m.add_class::<Occur>()?;
9091

9192
m.add_wrapped(wrap_pymodule!(query_parser_error))?;
9293

src/query.rs

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,55 @@
11
use crate::{make_term, Schema};
2-
use pyo3::{exceptions, prelude::*, types::PyAny, types::PyString};
2+
use pyo3::{exceptions, prelude::*, types::PyAny, types::PyString, types::PyTuple};
33
use tantivy as tv;
44

5+
/// Tuple struct pairing an `Occur` with the `Query` it applies to,
6+
/// i.e. one clause of a `BooleanQuery`.
7+
// NOTE(review): `boolean_query` in this diff takes `Vec<(Occur, Query)>`
// directly, so this type appears unused here — confirm before keeping it.
struct OccurQueryPair(Occur, Query);
8+
9+
/// Builds an `OccurQueryPair` from a Python tuple: item 0 is extracted as the
/// `Occur`, item 1 as the `Query`.
impl <'source> FromPyObject<'source> for OccurQueryPair {
10+
// Errors from the downcast, from fetching an item, or from converting an
// item are all propagated to the caller through `?`.
// Only indices 0 and 1 are read; any further tuple items are not inspected.
fn extract(ob: &'source PyAny) -> PyResult<Self> {
11+
let tuple = ob.downcast::<PyTuple>()?;
12+
let occur = tuple.get_item(0)?.extract()?;
13+
let query = tuple.get_item(1)?.extract()?;
14+
15+
Ok(OccurQueryPair(occur, query))
16+
}
17+
}
18+
19+
20+
/// Tantivy's Occur
///
/// Tags a subquery passed to `Query::boolean_query` as `Must`, `Should` or
/// `MustNot`. Each variant converts to the `tv::query::Occur` variant of the
/// same name.
21+
#[pyclass(frozen, module = "tantivy.tantivy")]
22+
#[derive(Clone)]
23+
pub enum Occur {
24+
Must,
25+
Should,
26+
MustNot,
27+
}
28+
29+
/// Maps the Python-facing `Occur` onto Tantivy's own `Occur`, variant for variant.
impl From<Occur> for tv::query::Occur {
30+
fn from(occur: Occur) -> tv::query::Occur {
31+
// No wildcard arm: adding a variant to `Occur` forces this mapping to be
// extended before the crate compiles again.
match occur {
32+
Occur::Must => tv::query::Occur::Must,
33+
Occur::Should => tv::query::Occur::Should,
34+
Occur::MustNot => tv::query::Occur::MustNot,
35+
}
36+
}
37+
}
38+
539
/// Tantivy's Query
///
/// Python-visible wrapper that owns a boxed `tv::query::Query` trait object.
640
#[pyclass(frozen, module = "tantivy.tantivy")]
741
pub(crate) struct Query {
842
// The wrapped Tantivy query, stored behind `Box<dyn ...>` so the wrapper
// does not depend on the concrete query type.
pub(crate) inner: Box<dyn tv::query::Query>,
943
}
1044

45+
/// Clones the wrapper by asking the boxed query to duplicate itself.
impl Clone for Query {
46+
fn clone(&self) -> Self {
47+
Query {
48+
// `inner` is a boxed trait object; the copy is produced by its
// `box_clone()` method.
inner: self.inner.box_clone(),
49+
}
50+
}
51+
}
52+
1153
impl Query {
1254
pub(crate) fn get(&self) -> &dyn tv::query::Query {
1355
&self.inner
@@ -91,4 +133,21 @@ impl Query {
91133
inner: Box::new(inner),
92134
})
93135
}
136+
137+
/// Builds a boolean query out of `(Occur, Query)` pairs.
///
/// Each pair is turned into Tantivy's `Occur` plus a boxed copy of the
/// subquery, and the resulting list is handed to
/// `tv::query::BooleanQuery::from`. The body itself has no failing path and
/// always returns `Ok`.
#[staticmethod]
138+
#[pyo3(signature = (subqueries))]
139+
pub(crate) fn boolean_query(
140+
subqueries: Vec<(Occur, Query)>
141+
) -> PyResult<Query> {
142+
let dyn_subqueries = subqueries
143+
.into_iter()
144+
// NOTE(review): `query` is already an owned value here, so `box_clone()`
// makes one more copy; moving `query.inner` may be enough — confirm.
.map(|(occur, query)| (occur.into(), query.inner.box_clone()))
145+
.collect::<Vec<_>>();
146+
147+
let inner = tv::query::BooleanQuery::from(dyn_subqueries);
148+
149+
Ok(Query {
150+
inner: Box::new(inner),
151+
})
152+
}
94153
}

tantivy/tantivy.pyi

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import datetime
22
from enum import Enum
3-
from typing import Any, Optional
3+
from typing import Any, Optional, Sequence
44

55

66
class Schema:
@@ -187,6 +187,10 @@ class Document:
187187
def get_all(self, field_name: str) -> list[Any]:
188188
pass
189189

190+
# Type stub for the `Occur` enum: tags each subquery given to
# `Query.boolean_query` as Must, Should or MustNot.
# NOTE(review): the integer values below are what the stub declares; verify
# that they match the values the compiled extension actually exposes.
class Occur(Enum):
191+
Must = 1
192+
Should = 2
193+
MustNot = 3
190194

191195
class Query:
192196
@staticmethod
@@ -200,7 +204,10 @@ class Query:
200204
@staticmethod
201205
def fuzzy_term_query(schema: Schema, field_name: str, text: str, distance: int = 1, transposition_cost_one: bool = True, prefix = False) -> Query:
202206
pass
203-
207+
208+
# Type stub: combines a sequence of (Occur, Query) pairs into one Query.
@staticmethod
209+
def boolean_query(subqueries: Sequence[tuple[Occur, Query]]) -> Query:
210+
pass
204211

205212
class Order(Enum):
206213
Asc = 1

tests/tantivy_test.py

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import tantivy
1010
from conftest import schema, schema_numeric_fields
11-
from tantivy import Document, Index, SchemaBuilder, SnippetGenerator, Query
11+
from tantivy import Document, Index, SchemaBuilder, SnippetGenerator, Query, Occur
1212

1313

1414
class TestClass(object):
@@ -819,4 +819,62 @@ def test_fuzzy_term_query(self, ram_index):
819819
titles.update(index.searcher().doc(doc_address)["title"])
820820
assert titles == {"Frankenstein", "The Modern Prometheus"}
821821

822-
822+
def test_boolean_query(self, ram_index):
# Exercises Query.boolean_query with Must, Should and MustNot clauses,
# with a tuple instead of a list, and with malformed (Occur, Query) pairs.
823+
index = ram_index
824+
query1 = Query.fuzzy_term_query(index.schema, "title", "ice")
825+
query2 = Query.fuzzy_term_query(index.schema, "title", "mna")
826+
query = Query.boolean_query([
827+
(Occur.Must, query1),
828+
(Occur.Must, query2)
829+
])
830+
831+
# both clauses are Must: no document is expected to match both queries
832+
result = index.searcher().search(query, 10)
833+
assert len(result.hits) == 0
834+
835+
query = Query.boolean_query([
836+
(Occur.Should, query1),
837+
(Occur.Should, query2)
838+
])
839+
840+
# both clauses are Should: two documents are expected, one for each query
841+
result = index.searcher().search(query, 10)
842+
assert len(result.hits) == 2
843+
844+
titles = set()
845+
for _, doc_address in result.hits:
846+
titles.update(index.searcher().doc(doc_address)["title"])
847+
assert (
848+
"The Old Man and the Sea" in titles and
849+
"Of Mice and Men" in titles
850+
)
851+
852+
query = Query.boolean_query([
853+
(Occur.MustNot, query1),
854+
(Occur.Must, query1)
855+
])
856+
857+
# MustNot is expected to take precedence over Must on the same subquery
858+
result = index.searcher().search(query, 10)
859+
assert len(result.hits) == 0
860+
861+
query = Query.boolean_query((
862+
(Occur.Should, query1),
863+
(Occur.Should, query2)
864+
))
865+
866+
# a tuple of pairs is expected to be accepted just like a list of pairs
867+
result = index.searcher().search(query, 10)
868+
assert len(result.hits) == 2
869+
870+
# a pair with three items is expected to be rejected with ValueError
871+
with pytest.raises(ValueError, match = "expected tuple of length 2, but got tuple of length 3"):
872+
Query.boolean_query([
873+
(Occur.Must, Occur.Must, query1),
874+
])
875+
876+
# a (Query, Occur) pair in the wrong order is expected to be rejected with TypeError
877+
with pytest.raises(TypeError, match = r"'Query' object cannot be converted to 'Occur'"):
878+
Query.boolean_query([
879+
(query1, Occur.Must),
880+
])

0 commit comments

Comments
 (0)