Skip to content

Commit 2b1439c

Browse files
authored
Merge pull request #65 from Sidhant29/add_float_support
Add Float Support
2 parents 164adc8 + 520159f commit 2b1439c

File tree

3 files changed

+108
-4
lines changed

3 files changed

+108
-4
lines changed

src/document.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,15 @@ impl Document {
300300
add_value(self, field_name, value);
301301
}
302302

303+
/// Add a float value to the document.
304+
///
305+
/// Args:
306+
/// field_name (str): The field name for which we are adding the value.
307+
/// value (f64): The float that will be added to the document.
308+
fn add_float(&mut self, field_name: String, value: f64) {
309+
add_value(self, field_name, value);
310+
}
311+
303312
/// Add a date value to the document.
304313
///
305314
/// Args:

src/schemabuilder.rs

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ impl SchemaBuilder {
120120
) -> PyResult<Self> {
121121
let builder = &mut self.builder;
122122

123-
let opts = SchemaBuilder::build_int_option(stored, indexed, fast)?;
123+
let opts = SchemaBuilder::build_numeric_option(stored, indexed, fast)?;
124124

125125
if let Some(builder) = builder.write().unwrap().as_mut() {
126126
builder.add_i64_field(name, opts);
@@ -132,6 +132,28 @@ impl SchemaBuilder {
132132
Ok(self.clone())
133133
}
134134

135+
#[pyo3(signature = (name, stored = false, indexed = false, fast = None))]
136+
fn add_float_field(
137+
&mut self,
138+
name: &str,
139+
stored: bool,
140+
indexed: bool,
141+
fast: Option<&str>,
142+
) -> PyResult<Self> {
143+
let builder = &mut self.builder;
144+
145+
let opts = SchemaBuilder::build_numeric_option(stored, indexed, fast)?;
146+
147+
if let Some(builder) = builder.write().unwrap().as_mut() {
148+
builder.add_f64_field(name, opts);
149+
} else {
150+
return Err(exceptions::PyValueError::new_err(
151+
"Schema builder object isn't valid anymore.",
152+
));
153+
}
154+
Ok(self.clone())
155+
}
156+
135157
/// Add a new unsigned integer field to the schema.
136158
///
137159
/// Args:
@@ -162,7 +184,7 @@ impl SchemaBuilder {
162184
) -> PyResult<Self> {
163185
let builder = &mut self.builder;
164186

165-
let opts = SchemaBuilder::build_int_option(stored, indexed, fast)?;
187+
let opts = SchemaBuilder::build_numeric_option(stored, indexed, fast)?;
166188

167189
if let Some(builder) = builder.write().unwrap().as_mut() {
168190
builder.add_u64_field(name, opts);
@@ -343,7 +365,7 @@ impl SchemaBuilder {
343365
}
344366

345367
impl SchemaBuilder {
346-
fn build_int_option(
368+
fn build_numeric_option(
347369
stored: bool,
348370
indexed: bool,
349371
fast: Option<&str>,

tests/tantivy_test.py

Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,14 @@ def schema():
1212
.build()
1313
)
1414

15+
def schema_numeric_fields():
16+
return (
17+
SchemaBuilder()
18+
.add_integer_field("id", stored=True, indexed=True)
19+
.add_float_field("rating", stored=True, indexed=True)
20+
.add_text_field("body", stored=True)
21+
.build()
22+
)
1523

1624
def create_index(dir=None):
1725
# assume all tests will use the same documents for now
@@ -66,6 +74,46 @@ def create_index(dir=None):
6674
index.reload()
6775
return index
6876

77+
def create_index_with_numeric_fields(dir=None):
78+
index = Index(schema_numeric_fields(), dir)
79+
writer = index.writer()
80+
81+
doc = Document()
82+
doc.add_integer("id", 1)
83+
doc.add_float("rating", 3.5)
84+
doc.add_text(
85+
"body",
86+
(
87+
"He was an old man who fished alone in a skiff in"
88+
"the Gulf Stream and he had gone eighty-four days "
89+
"now without taking a fish."
90+
),
91+
)
92+
writer.add_document(doc)
93+
doc = Document.from_dict(
94+
{
95+
"id": 2,
96+
"rating": 4.5,
97+
"body": (
98+
"A few miles south of Soledad, the Salinas River drops "
99+
"in close to the hillside bank and runs deep and "
100+
"green. The water is warm too, for it has slipped "
101+
"twinkling over the yellow sands in the sunlight "
102+
"before reaching the narrow pool. On one side of the "
103+
"river the golden foothill slopes curve up to the "
104+
"strong and rocky Gabilan Mountains, but on the valley "
105+
"side the water is lined with trees—willows fresh and "
106+
"green with every spring, carrying in their lower leaf "
107+
"junctures the debris of the winter’s flooding; and "
108+
"sycamores with mottled, white, recumbent limbs and "
109+
"branches that arch over the pool"
110+
),
111+
}
112+
)
113+
writer.add_document(doc)
114+
writer.commit()
115+
index.reload()
116+
return index
69117

70118
def spanish_schema():
71119
return (
@@ -127,6 +175,11 @@ def ram_index():
127175
return create_index()
128176

129177

178+
@pytest.fixture(scope="class")
179+
def ram_index_numeric_fields():
180+
return create_index_with_numeric_fields()
181+
182+
130183
@pytest.fixture(scope="class")
131184
def spanish_index():
132185
return create_spanish_index()
@@ -185,6 +238,25 @@ def test_and_query(self, ram_index):
185238

186239
assert len(result.hits) == 1
187240

241+
def test_and_query_numeric_fields(self, ram_index_numeric_fields):
242+
index = ram_index_numeric_fields
243+
searcher = index.searcher()
244+
245+
# 1 result
246+
float_query = index.parse_query("3.5", ["rating"])
247+
result = searcher.search(float_query)
248+
assert len(result.hits) == 1
249+
assert searcher.doc(result.hits[0][1])['rating'][0] == 3.5
250+
251+
integer_query = index.parse_query("1", ["id"])
252+
result = searcher.search(integer_query)
253+
assert len(result.hits) == 1
254+
255+
# 0 result
256+
integer_query = index.parse_query("10", ["id"])
257+
result = searcher.search(integer_query)
258+
assert len(result.hits) == 0
259+
188260
def test_and_query_parser_default_fields(self, ram_index):
189261
query = ram_index.parse_query("winter", default_field_names=["title"])
190262
assert repr(query) == """Query(TermQuery(Term(type=Str, field=0, "winter")))"""
@@ -344,8 +416,9 @@ def test_create_readers(self):
344416

345417

346418
class TestSearcher(object):
347-
def test_searcher_repr(self, ram_index):
419+
def test_searcher_repr(self, ram_index, ram_index_numeric_fields):
348420
assert repr(ram_index.searcher()) == "Searcher(num_docs=3, num_segments=1)"
421+
assert repr(ram_index_numeric_fields.searcher()) == "Searcher(num_docs=2, num_segments=1)"
349422

350423

351424
class TestDocument(object):

0 commit comments

Comments
 (0)