Skip to content

Expose IndexWriter::wait_merging_threads() #100

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 22, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 45 additions & 11 deletions src/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,36 @@ const RELOAD_POLICY: &str = "commit";
/// on the index object.
#[pyclass]
pub(crate) struct IndexWriter {
inner_index_writer: tv::IndexWriter,
inner_index_writer: Option<tv::IndexWriter>,
schema: tv::schema::Schema,
}

impl IndexWriter {
    /// Shared error message used by every accessor once the writer has
    /// been consumed by `wait_merging_threads`. Kept in one place so the
    /// three accessors below cannot drift apart.
    const CONSUMED_ERROR_MSG: &'static str =
        "IndexWriter was consumed and no longer in a valid state";

    /// Immutably borrow the wrapped writer.
    ///
    /// Returns a `ValueError` if the writer was already consumed by
    /// `wait_merging_threads`.
    fn inner(&self) -> PyResult<&tv::IndexWriter> {
        self.inner_index_writer.as_ref().ok_or_else(|| {
            exceptions::PyValueError::new_err(Self::CONSUMED_ERROR_MSG)
        })
    }

    /// Mutably borrow the wrapped writer.
    ///
    /// Returns a `ValueError` if the writer was already consumed by
    /// `wait_merging_threads`.
    fn inner_mut(&mut self) -> PyResult<&mut tv::IndexWriter> {
        self.inner_index_writer.as_mut().ok_or_else(|| {
            exceptions::PyValueError::new_err(Self::CONSUMED_ERROR_MSG)
        })
    }

    /// Take ownership of the wrapped writer, leaving `None` behind so any
    /// later access reports the writer as consumed.
    ///
    /// Returns a `ValueError` if the writer was already consumed.
    fn take_inner(&mut self) -> PyResult<tv::IndexWriter> {
        self.inner_index_writer.take().ok_or_else(|| {
            exceptions::PyValueError::new_err(Self::CONSUMED_ERROR_MSG)
        })
    }
}

#[pymethods]
impl IndexWriter {
/// Add a document to the index.
Expand All @@ -45,7 +71,7 @@ impl IndexWriter {
pub fn add_document(&mut self, doc: &Document) -> PyResult<u64> {
let named_doc = NamedFieldDocument(doc.field_values.clone());
let doc = self.schema.convert_named_doc(named_doc).map_err(to_pyerr)?;
self.inner_index_writer.add_document(doc).map_err(to_pyerr)
self.inner()?.add_document(doc).map_err(to_pyerr)
}

/// Helper for the `add_document` method, but passing a json string.
Expand All @@ -58,7 +84,7 @@ impl IndexWriter {
/// since the creation of the index.
pub fn add_json(&mut self, json: &str) -> PyResult<u64> {
let doc = self.schema.parse_document(json).map_err(to_pyerr)?;
let opstamp = self.inner_index_writer.add_document(doc);
let opstamp = self.inner()?.add_document(doc);
opstamp.map_err(to_pyerr)
}

Expand All @@ -72,7 +98,7 @@ impl IndexWriter {
///
/// Returns the `opstamp` of the last document that made it in the commit.
fn commit(&mut self) -> PyResult<u64> {
self.inner_index_writer.commit().map_err(to_pyerr)
self.inner_mut()?.commit().map_err(to_pyerr)
}

/// Rollback to the last commit
Expand All @@ -81,14 +107,13 @@ impl IndexWriter {
/// commit. After calling rollback, the index is in the same state as it
/// was after the last commit.
fn rollback(&mut self) -> PyResult<u64> {
self.inner_index_writer.rollback().map_err(to_pyerr)
self.inner_mut()?.rollback().map_err(to_pyerr)
}

/// Detect and removes the files that are not used by the index anymore.
fn garbage_collect_files(&mut self) -> PyResult<()> {
use futures::executor::block_on;
block_on(self.inner_index_writer.garbage_collect_files())
.map_err(to_pyerr)?;
block_on(self.inner()?.garbage_collect_files()).map_err(to_pyerr)?;
Ok(())
}

Expand All @@ -100,8 +125,8 @@ impl IndexWriter {
/// This is also the opstamp of the commit that is currently available
/// for searchers.
#[getter]
fn commit_opstamp(&self) -> u64 {
self.inner_index_writer.commit_opstamp()
fn commit_opstamp(&self) -> PyResult<u64> {
Ok(self.inner()?.commit_opstamp())
}

/// Delete all documents containing a given term.
Expand Down Expand Up @@ -144,7 +169,16 @@ impl IndexWriter {
Value::Bool(b) => Term::from_field_bool(field, b),
Value::IpAddr(i) => Term::from_field_ip_addr(field, i)
};
Ok(self.inner_index_writer.delete_term(term))
Ok(self.inner()?.delete_term(term))
}

/// Blocks until any in-flight merging threads have finished their work,
/// then drops the underlying `IndexWriter`.
///
/// This consumes the wrapped writer: any further use of this object
/// results in a `ValueError`.
pub fn wait_merging_threads(&mut self) -> PyResult<()> {
    self.take_inner()
        .and_then(|writer| writer.wait_merging_threads().map_err(to_pyerr))
}
}

Expand Down Expand Up @@ -229,7 +263,7 @@ impl Index {
.map_err(to_pyerr)?;
let schema = self.index.schema();
Ok(IndexWriter {
inner_index_writer: writer,
inner_index_writer: Some(writer),
schema,
})
}
Expand Down
6 changes: 6 additions & 0 deletions src/searcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,12 @@ impl Searcher {
self.inner.num_docs()
}

/// Returns the number of segments in the index.
#[getter]
fn num_segments(&self) -> usize {
    let segment_readers = self.inner.segment_readers();
    segment_readers.len()
}

/// Fetches a document from Tantivy's store given a DocAddress.
///
/// Args:
Expand Down
41 changes: 41 additions & 0 deletions tests/tantivy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,47 @@ def test_order_by_search_without_fast_field(self):
result = searcher.search(query, 10, order_by_field="order")
assert len(result.hits) == 0

def test_with_merges(self):
# Verifies that wait_merging_threads() blocks until segment merges
# complete and consumes the writer, so later use raises ValueError.
# This test is taken from tantivy's test suite:
# https://github.com/quickwit-oss/tantivy/blob/42acd334f49d5ff7e4fe846b5c12198f24409b50/src/indexer/index_writer.rs#L1130
schema = SchemaBuilder().add_text_field("text", stored=True).build()

index = Index(schema)
# Manual reload policy: the test reloads explicitly below.
index.config_reader(reload_policy="Manual")

writer = index.writer()

# First batch: 100 identical single-term documents.
for _ in range(100):
doc = Document()
doc.add_text("text", "a")

writer.add_document(doc)

writer.commit()

# Second batch: another 100 documents before the next commit.
for _ in range(100):
doc = Document()
doc.add_text("text", "a")

writer.add_document(doc)

# This should create 8 segments and trigger a merge.
writer.commit()
writer.wait_merging_threads()

# Accessing the writer again should result in an error.
with pytest.raises(ValueError):
writer.wait_merging_threads()

index.reload()

# All 200 documents must still be searchable after the merge.
query = index.parse_query("a")
searcher = index.searcher()
result = searcher.search(query, limit=500, count=True)
assert result.count == 200

# The merge must have reduced the segment count below the 8 created.
assert searcher.num_segments < 8

def test_doc_from_dict_schema_validation(self):
schema = (
SchemaBuilder()
Expand Down