Skip to content

Commit 27caa50

Browse files
committed
Add schema validation to PyDict -> Document
1 parent 468166f commit 27caa50

File tree

2 files changed

+273
-41
lines changed

2 files changed

+273
-41
lines changed

src/document.rs

Lines changed: 179 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ use chrono::{offset::TimeZone, Utc};
1313

1414
use tantivy as tv;
1515

16-
use crate::{facet::Facet, to_pyerr};
16+
use crate::{facet::Facet, schema::Schema, to_pyerr};
1717
use serde_json::Value as JsonValue;
1818
use std::{
1919
collections::{BTreeMap, HashMap},
@@ -200,6 +200,71 @@ pub(crate) fn extract_value(any: &PyAny) -> PyResult<Value> {
200200
Err(to_pyerr(format!("Value unsupported {any:?}")))
201201
}
202202

203+
pub(crate) fn extract_value_for_type(
204+
any: &PyAny,
205+
tv_type: tv::schema::Type,
206+
field_name: &str,
207+
) -> PyResult<Value> {
208+
// Helper function to create `PyErr`s returned by this function.
209+
fn to_pyerr_for_type<'a, E: std::error::Error>(
210+
type_name: &'a str,
211+
field_name: &'a str,
212+
any: &'a PyAny,
213+
) -> impl Fn(E) -> PyErr + 'a {
214+
move |_| {
215+
to_pyerr(format!(
216+
"Expected {} type for field {}, got {:?}",
217+
type_name, field_name, any
218+
))
219+
}
220+
}
221+
222+
let value = match tv_type {
223+
tv::schema::Type::Str => Value::Str(
224+
any.extract::<String>()
225+
.map_err(to_pyerr_for_type("Str", field_name, any))?,
226+
),
227+
tv::schema::Type::U64 => Value::U64(
228+
any.extract::<u64>()
229+
.map_err(to_pyerr_for_type("U64", field_name, any))?,
230+
),
231+
tv::schema::Type::I64 => Value::I64(
232+
any.extract::<i64>()
233+
.map_err(to_pyerr_for_type("I64", field_name, any))?,
234+
),
235+
tv::schema::Type::F64 => Value::F64(
236+
any.extract::<f64>()
237+
.map_err(to_pyerr_for_type("F64", field_name, any))?,
238+
),
239+
tv::schema::Type::Date => {
240+
let py_datetime = any
241+
.downcast::<PyDateTime>()
242+
.map_err(to_pyerr_for_type("DateTime", field_name, any))?;
243+
let datetime = Utc
244+
.with_ymd_and_hms(
245+
py_datetime.get_year(),
246+
py_datetime.get_month().into(),
247+
py_datetime.get_day().into(),
248+
py_datetime.get_hour().into(),
249+
py_datetime.get_minute().into(),
250+
py_datetime.get_second().into(),
251+
)
252+
.single()
253+
.unwrap();
254+
255+
Value::Date(tv::DateTime::from_timestamp_secs(datetime.timestamp()))
256+
}
257+
tv::schema::Type::Facet => Value::Facet(
258+
any.extract::<Facet>()
259+
.map_err(to_pyerr_for_type("Facet", field_name, any))?
260+
.inner,
261+
),
262+
_ => return Err(to_pyerr(format!("Value unsupported {:?}", any))),
263+
};
264+
265+
Ok(value)
266+
}
267+
203268
fn extract_value_single_or_list(any: &PyAny) -> PyResult<Vec<Value>> {
204269
if let Ok(values) = any.downcast::<PyList>() {
205270
values.iter().map(extract_value).collect()
@@ -208,51 +273,135 @@ fn extract_value_single_or_list(any: &PyAny) -> PyResult<Vec<Value>> {
208273
}
209274
}
210275

211-
#[pymethods]
212-
impl Document {
213-
#[new]
214-
#[pyo3(signature = (**kwargs))]
215-
fn new(kwargs: Option<&PyDict>) -> PyResult<Self> {
216-
let mut document = Document::default();
217-
if let Some(field_dict) = kwargs {
218-
document.extend(field_dict)?;
276+
fn extract_value_single_or_list_for_type(
277+
any: &PyAny,
278+
field_type: &tv::schema::FieldType,
279+
field_name: &str,
280+
) -> PyResult<Vec<Value>> {
281+
// Check if a numeric fast field supports multivalues.
282+
let allow_multivalues = match field_type {
283+
tv::schema::FieldType::U64(opts)
284+
| tv::schema::FieldType::I64(opts)
285+
| tv::schema::FieldType::F64(opts) => opts
286+
.get_fastfield_cardinality()
287+
.map(|card| card == tv::schema::Cardinality::MultiValues)
288+
.unwrap_or(true),
289+
_ => true,
290+
};
291+
292+
if let Ok(values) = any.downcast::<PyList>() {
293+
let result = values
294+
.iter()
295+
.map(|any| {
296+
extract_value_for_type(any, field_type.value_type(), field_name)
297+
})
298+
.collect::<PyResult<Vec<Value>>>()?;
299+
300+
if !allow_multivalues && result.len() > 1 {
301+
return Err(to_pyerr(format!(
302+
"Field '{}' does not support multivalues",
303+
field_name
304+
)));
305+
} else {
306+
Ok(result)
219307
}
220-
Ok(document)
308+
} else {
309+
Ok(vec![extract_value_for_type(
310+
any,
311+
field_type.value_type(),
312+
field_name,
313+
)?])
221314
}
315+
}
316+
317+
impl Document {
318+
fn extract_py_values_from_dict(
319+
py_dict: &PyDict,
320+
schema: Option<&Schema>,
321+
out_field_values: &mut BTreeMap<String, Vec<tv::schema::Value>>,
322+
) -> PyResult<()> {
323+
// TODO: Reserve when https://github.com/rust-lang/rust/issues/72631 is stable.
324+
// out_field_values.reserve(py_dict.len());
222325

223-
fn extend(&mut self, py_dict: &PyDict) -> PyResult<()> {
224-
let mut field_values: BTreeMap<String, Vec<tv::schema::Value>> =
225-
BTreeMap::new();
226326
for key_value_any in py_dict.items() {
227327
if let Ok(key_value) = key_value_any.downcast::<PyTuple>() {
228328
if key_value.len() != 2 {
229329
continue;
230330
}
231-
let key: String = key_value.get_item(0)?.extract()?;
232-
let value_list =
233-
extract_value_single_or_list(key_value.get_item(1)?)?;
234-
field_values.insert(key, value_list);
331+
let key = key_value.get_item(0)?.extract::<String>()?;
332+
333+
let field_type = if let Some(schema) = schema {
334+
let field_type =
335+
schema.inner.get_field(key.as_str()).map(|field| {
336+
schema.inner.get_field_entry(field).field_type()
337+
});
338+
339+
if let Some(field_type) = field_type {
340+
// A field type was found, so validate it after the values are extracted.
341+
Some(field_type)
342+
} else {
343+
// The field does not exist in the schema, so skip over it.
344+
continue;
345+
}
346+
} else {
347+
// No schema was provided, so do not validate anything.
348+
None
349+
};
350+
351+
let value_list = if let Some(field_type) = field_type {
352+
extract_value_single_or_list_for_type(
353+
key_value.get_item(1)?,
354+
field_type,
355+
key.as_str(),
356+
)?
357+
} else {
358+
extract_value_single_or_list(key_value.get_item(1)?)?
359+
};
360+
361+
out_field_values.insert(key, value_list);
235362
}
236363
}
237-
self.field_values.extend(field_values.into_iter());
364+
238365
Ok(())
239366
}
367+
}
368+
369+
#[pymethods]
370+
impl Document {
371+
#[new]
372+
#[pyo3(signature = (**kwargs))]
373+
fn new(kwargs: Option<&PyDict>) -> PyResult<Self> {
374+
let mut document = Document::default();
375+
if let Some(field_dict) = kwargs {
376+
document.extend(field_dict, None)?;
377+
}
378+
Ok(document)
379+
}
380+
381+
fn extend(
382+
&mut self,
383+
py_dict: &PyDict,
384+
schema: Option<&Schema>,
385+
) -> PyResult<()> {
386+
Document::extract_py_values_from_dict(
387+
py_dict,
388+
schema,
389+
&mut self.field_values,
390+
)
391+
}
240392

241393
#[staticmethod]
242-
fn from_dict(py_dict: &PyDict) -> PyResult<Document> {
394+
fn from_dict(
395+
py_dict: &PyDict,
396+
schema: Option<&Schema>,
397+
) -> PyResult<Document> {
243398
let mut field_values: BTreeMap<String, Vec<tv::schema::Value>> =
244399
BTreeMap::new();
245-
for key_value_any in py_dict.items() {
246-
if let Ok(key_value) = key_value_any.downcast::<PyTuple>() {
247-
if key_value.len() != 2 {
248-
continue;
249-
}
250-
let key: String = key_value.get_item(0)?.extract()?;
251-
let value_list =
252-
extract_value_single_or_list(key_value.get_item(1)?)?;
253-
field_values.insert(key, value_list);
254-
}
255-
}
400+
Document::extract_py_values_from_dict(
401+
py_dict,
402+
schema,
403+
&mut field_values,
404+
)?;
256405
Ok(Document { field_values })
257406
}
258407

0 commit comments

Comments
 (0)