Skip to content

Commit a3abbee

Browse files
committed
Add schema validation to PyDict -> Document
1 parent a266f41 commit a3abbee

File tree

3 files changed

+249
-55
lines changed

3 files changed

+249
-55
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,4 @@ serde_json = "1.0.91"
2222

2323
[dependencies.pyo3]
2424
version = "0.18.0"
25-
features = ["extension-module"]
25+
features = ["chrono", "extension-module"]

src/document.rs

Lines changed: 154 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,11 @@ use pyo3::{
99
},
1010
};
1111

12-
use chrono::{offset::TimeZone, Utc};
12+
use chrono::{offset::TimeZone, DateTime, Utc};
1313

1414
use tantivy as tv;
1515

16-
use crate::{facet::Facet, to_pyerr};
16+
use crate::{facet::Facet, schema::Schema, to_pyerr};
1717
use serde_json::Value as JsonValue;
1818
use std::{
1919
collections::{BTreeMap, HashMap},
@@ -175,18 +175,7 @@ pub(crate) fn extract_value(any: &PyAny) -> PyResult<Value> {
175175
if let Ok(num) = any.extract::<f64>() {
176176
return Ok(Value::F64(num));
177177
}
178-
if let Ok(py_datetime) = any.downcast::<PyDateTime>() {
179-
let datetime = Utc
180-
.with_ymd_and_hms(
181-
py_datetime.get_year(),
182-
py_datetime.get_month().into(),
183-
py_datetime.get_day().into(),
184-
py_datetime.get_hour().into(),
185-
py_datetime.get_minute().into(),
186-
py_datetime.get_second().into(),
187-
)
188-
.single()
189-
.unwrap();
178+
if let Ok(datetime) = any.extract::<DateTime<Utc>>() {
190179
return Ok(Value::Date(tv::DateTime::from_timestamp_secs(
191180
datetime.timestamp(),
192181
)));
@@ -200,6 +189,60 @@ pub(crate) fn extract_value(any: &PyAny) -> PyResult<Value> {
200189
Err(to_pyerr(format!("Value unsupported {any:?}")))
201190
}
202191

192+
pub(crate) fn extract_value_for_type(
193+
any: &PyAny,
194+
tv_type: tv::schema::Type,
195+
field_name: &str,
196+
) -> PyResult<Value> {
197+
// Helper function to create `PyErr`s returned by this function.
198+
fn to_pyerr_for_type<'a, E: std::error::Error>(
199+
type_name: &'a str,
200+
field_name: &'a str,
201+
any: &'a PyAny,
202+
) -> impl Fn(E) -> PyErr + 'a {
203+
move |_| {
204+
to_pyerr(format!(
205+
"Expected {} type for field {}, got {:?}",
206+
type_name, field_name, any
207+
))
208+
}
209+
}
210+
211+
let value = match tv_type {
212+
tv::schema::Type::Str => Value::Str(
213+
any.extract::<String>()
214+
.map_err(to_pyerr_for_type("Str", field_name, any))?,
215+
),
216+
tv::schema::Type::U64 => Value::U64(
217+
any.extract::<u64>()
218+
.map_err(to_pyerr_for_type("U64", field_name, any))?,
219+
),
220+
tv::schema::Type::I64 => Value::I64(
221+
any.extract::<i64>()
222+
.map_err(to_pyerr_for_type("I64", field_name, any))?,
223+
),
224+
tv::schema::Type::F64 => Value::F64(
225+
any.extract::<f64>()
226+
.map_err(to_pyerr_for_type("F64", field_name, any))?,
227+
),
228+
tv::schema::Type::Date => {
229+
let datetime = any
230+
.extract::<DateTime<Utc>>()
231+
.map_err(to_pyerr_for_type("DateTime", field_name, any))?;
232+
233+
Value::Date(tv::DateTime::from_timestamp_secs(datetime.timestamp()))
234+
}
235+
tv::schema::Type::Facet => Value::Facet(
236+
any.extract::<Facet>()
237+
.map_err(to_pyerr_for_type("Facet", field_name, any))?
238+
.inner,
239+
),
240+
_ => return Err(to_pyerr(format!("Value unsupported {:?}", any))),
241+
};
242+
243+
Ok(value)
244+
}
245+
203246
fn extract_value_single_or_list(any: &PyAny) -> PyResult<Vec<Value>> {
204247
if let Ok(values) = any.downcast::<PyList>() {
205248
values.iter().map(extract_value).collect()
@@ -208,51 +251,119 @@ fn extract_value_single_or_list(any: &PyAny) -> PyResult<Vec<Value>> {
208251
}
209252
}
210253

254+
fn extract_value_single_or_list_for_type(
255+
any: &PyAny,
256+
field_type: &tv::schema::FieldType,
257+
field_name: &str,
258+
) -> PyResult<Vec<Value>> {
259+
// Check if a numeric fast field supports multivalues.
260+
if let Ok(values) = any.downcast::<PyList>() {
261+
values
262+
.iter()
263+
.map(|any| {
264+
extract_value_for_type(any, field_type.value_type(), field_name)
265+
})
266+
.collect::<PyResult<Vec<Value>>>()
267+
} else {
268+
Ok(vec![extract_value_for_type(
269+
any,
270+
field_type.value_type(),
271+
field_name,
272+
)?])
273+
}
274+
}
275+
276+
impl Document {
277+
fn extract_py_values_from_dict(
278+
py_dict: &PyDict,
279+
schema: Option<&Schema>,
280+
out_field_values: &mut BTreeMap<String, Vec<tv::schema::Value>>,
281+
) -> PyResult<()> {
282+
// TODO: Reserve when https://github.com/rust-lang/rust/issues/72631 is stable.
283+
// out_field_values.reserve(py_dict.len());
284+
285+
for key_value_any in py_dict.items() {
286+
if let Ok(key_value) = key_value_any.downcast::<PyTuple>() {
287+
if key_value.len() != 2 {
288+
continue;
289+
}
290+
let key = key_value.get_item(0)?.extract::<String>()?;
291+
292+
let field_type = if let Some(schema) = schema {
293+
let field_type = schema
294+
.inner
295+
.get_field(key.as_str())
296+
.map(|field| {
297+
schema.inner.get_field_entry(field).field_type()
298+
})
299+
.ok();
300+
301+
if let Some(field_type) = field_type {
302+
// A field type was found, so validate it after the values are extracted.
303+
Some(field_type)
304+
} else {
305+
// The field does not exist in the schema, so skip over it.
306+
continue;
307+
}
308+
} else {
309+
// No schema was provided, so do not validate anything.
310+
None
311+
};
312+
313+
let value_list = if let Some(field_type) = field_type {
314+
extract_value_single_or_list_for_type(
315+
key_value.get_item(1)?,
316+
field_type,
317+
key.as_str(),
318+
)?
319+
} else {
320+
extract_value_single_or_list(key_value.get_item(1)?)?
321+
};
322+
323+
out_field_values.insert(key, value_list);
324+
}
325+
}
326+
327+
Ok(())
328+
}
329+
}
330+
211331
#[pymethods]
212332
impl Document {
213333
#[new]
214334
#[pyo3(signature = (**kwargs))]
215335
fn new(kwargs: Option<&PyDict>) -> PyResult<Self> {
216336
let mut document = Document::default();
217337
if let Some(field_dict) = kwargs {
218-
document.extend(field_dict)?;
338+
document.extend(field_dict, None)?;
219339
}
220340
Ok(document)
221341
}
222342

223-
fn extend(&mut self, py_dict: &PyDict) -> PyResult<()> {
224-
let mut field_values: BTreeMap<String, Vec<tv::schema::Value>> =
225-
BTreeMap::new();
226-
for key_value_any in py_dict.items() {
227-
if let Ok(key_value) = key_value_any.downcast::<PyTuple>() {
228-
if key_value.len() != 2 {
229-
continue;
230-
}
231-
let key: String = key_value.get_item(0)?.extract()?;
232-
let value_list =
233-
extract_value_single_or_list(key_value.get_item(1)?)?;
234-
field_values.insert(key, value_list);
235-
}
236-
}
237-
self.field_values.extend(field_values.into_iter());
238-
Ok(())
343+
fn extend(
344+
&mut self,
345+
py_dict: &PyDict,
346+
schema: Option<&Schema>,
347+
) -> PyResult<()> {
348+
Document::extract_py_values_from_dict(
349+
py_dict,
350+
schema,
351+
&mut self.field_values,
352+
)
239353
}
240354

241355
#[staticmethod]
242-
fn from_dict(py_dict: &PyDict) -> PyResult<Document> {
356+
fn from_dict(
357+
py_dict: &PyDict,
358+
schema: Option<&Schema>,
359+
) -> PyResult<Document> {
243360
let mut field_values: BTreeMap<String, Vec<tv::schema::Value>> =
244361
BTreeMap::new();
245-
for key_value_any in py_dict.items() {
246-
if let Ok(key_value) = key_value_any.downcast::<PyTuple>() {
247-
if key_value.len() != 2 {
248-
continue;
249-
}
250-
let key: String = key_value.get_item(0)?.extract()?;
251-
let value_list =
252-
extract_value_single_or_list(key_value.get_item(1)?)?;
253-
field_values.insert(key, value_list);
254-
}
255-
}
362+
Document::extract_py_values_from_dict(
363+
py_dict,
364+
schema,
365+
&mut field_values,
366+
)?;
256367
Ok(Document { field_values })
257368
}
258369

0 commit comments

Comments
 (0)