@@ -9,11 +9,11 @@ use pyo3::{
9
9
} ,
10
10
} ;
11
11
12
- use chrono:: { offset:: TimeZone , Utc } ;
12
+ use chrono:: { offset:: TimeZone , DateTime , Utc } ;
13
13
14
14
use tantivy as tv;
15
15
16
- use crate :: { facet:: Facet , to_pyerr} ;
16
+ use crate :: { facet:: Facet , schema :: Schema , to_pyerr} ;
17
17
use serde_json:: Value as JsonValue ;
18
18
use std:: {
19
19
collections:: { BTreeMap , HashMap } ,
@@ -175,18 +175,7 @@ pub(crate) fn extract_value(any: &PyAny) -> PyResult<Value> {
175
175
if let Ok ( num) = any. extract :: < f64 > ( ) {
176
176
return Ok ( Value :: F64 ( num) ) ;
177
177
}
178
- if let Ok ( py_datetime) = any. downcast :: < PyDateTime > ( ) {
179
- let datetime = Utc
180
- . with_ymd_and_hms (
181
- py_datetime. get_year ( ) ,
182
- py_datetime. get_month ( ) . into ( ) ,
183
- py_datetime. get_day ( ) . into ( ) ,
184
- py_datetime. get_hour ( ) . into ( ) ,
185
- py_datetime. get_minute ( ) . into ( ) ,
186
- py_datetime. get_second ( ) . into ( ) ,
187
- )
188
- . single ( )
189
- . unwrap ( ) ;
178
+ if let Ok ( datetime) = any. extract :: < DateTime < Utc > > ( ) {
190
179
return Ok ( Value :: Date ( tv:: DateTime :: from_timestamp_secs (
191
180
datetime. timestamp ( ) ,
192
181
) ) ) ;
@@ -200,6 +189,60 @@ pub(crate) fn extract_value(any: &PyAny) -> PyResult<Value> {
200
189
Err ( to_pyerr ( format ! ( "Value unsupported {any:?}" ) ) )
201
190
}
202
191
192
+ pub ( crate ) fn extract_value_for_type (
193
+ any : & PyAny ,
194
+ tv_type : tv:: schema:: Type ,
195
+ field_name : & str ,
196
+ ) -> PyResult < Value > {
197
+ // Helper function to create `PyErr`s returned by this function.
198
+ fn to_pyerr_for_type < ' a , E : std:: error:: Error > (
199
+ type_name : & ' a str ,
200
+ field_name : & ' a str ,
201
+ any : & ' a PyAny ,
202
+ ) -> impl Fn ( E ) -> PyErr + ' a {
203
+ move |_| {
204
+ to_pyerr ( format ! (
205
+ "Expected {} type for field {}, got {:?}" ,
206
+ type_name, field_name, any
207
+ ) )
208
+ }
209
+ }
210
+
211
+ let value = match tv_type {
212
+ tv:: schema:: Type :: Str => Value :: Str (
213
+ any. extract :: < String > ( )
214
+ . map_err ( to_pyerr_for_type ( "Str" , field_name, any) ) ?,
215
+ ) ,
216
+ tv:: schema:: Type :: U64 => Value :: U64 (
217
+ any. extract :: < u64 > ( )
218
+ . map_err ( to_pyerr_for_type ( "U64" , field_name, any) ) ?,
219
+ ) ,
220
+ tv:: schema:: Type :: I64 => Value :: I64 (
221
+ any. extract :: < i64 > ( )
222
+ . map_err ( to_pyerr_for_type ( "I64" , field_name, any) ) ?,
223
+ ) ,
224
+ tv:: schema:: Type :: F64 => Value :: F64 (
225
+ any. extract :: < f64 > ( )
226
+ . map_err ( to_pyerr_for_type ( "F64" , field_name, any) ) ?,
227
+ ) ,
228
+ tv:: schema:: Type :: Date => {
229
+ let datetime = any
230
+ . extract :: < DateTime < Utc > > ( )
231
+ . map_err ( to_pyerr_for_type ( "DateTime" , field_name, any) ) ?;
232
+
233
+ Value :: Date ( tv:: DateTime :: from_timestamp_secs ( datetime. timestamp ( ) ) )
234
+ }
235
+ tv:: schema:: Type :: Facet => Value :: Facet (
236
+ any. extract :: < Facet > ( )
237
+ . map_err ( to_pyerr_for_type ( "Facet" , field_name, any) ) ?
238
+ . inner ,
239
+ ) ,
240
+ _ => return Err ( to_pyerr ( format ! ( "Value unsupported {:?}" , any) ) ) ,
241
+ } ;
242
+
243
+ Ok ( value)
244
+ }
245
+
203
246
fn extract_value_single_or_list ( any : & PyAny ) -> PyResult < Vec < Value > > {
204
247
if let Ok ( values) = any. downcast :: < PyList > ( ) {
205
248
values. iter ( ) . map ( extract_value) . collect ( )
@@ -208,51 +251,119 @@ fn extract_value_single_or_list(any: &PyAny) -> PyResult<Vec<Value>> {
208
251
}
209
252
}
210
253
254
+ fn extract_value_single_or_list_for_type (
255
+ any : & PyAny ,
256
+ field_type : & tv:: schema:: FieldType ,
257
+ field_name : & str ,
258
+ ) -> PyResult < Vec < Value > > {
259
+ // Check if a numeric fast field supports multivalues.
260
+ if let Ok ( values) = any. downcast :: < PyList > ( ) {
261
+ values
262
+ . iter ( )
263
+ . map ( |any| {
264
+ extract_value_for_type ( any, field_type. value_type ( ) , field_name)
265
+ } )
266
+ . collect :: < PyResult < Vec < Value > > > ( )
267
+ } else {
268
+ Ok ( vec ! [ extract_value_for_type(
269
+ any,
270
+ field_type. value_type( ) ,
271
+ field_name,
272
+ ) ?] )
273
+ }
274
+ }
275
+
276
+ impl Document {
277
+ fn extract_py_values_from_dict (
278
+ py_dict : & PyDict ,
279
+ schema : Option < & Schema > ,
280
+ out_field_values : & mut BTreeMap < String , Vec < tv:: schema:: Value > > ,
281
+ ) -> PyResult < ( ) > {
282
+ // TODO: Reserve when https://github.com/rust-lang/rust/issues/72631 is stable.
283
+ // out_field_values.reserve(py_dict.len());
284
+
285
+ for key_value_any in py_dict. items ( ) {
286
+ if let Ok ( key_value) = key_value_any. downcast :: < PyTuple > ( ) {
287
+ if key_value. len ( ) != 2 {
288
+ continue ;
289
+ }
290
+ let key = key_value. get_item ( 0 ) ?. extract :: < String > ( ) ?;
291
+
292
+ let field_type = if let Some ( schema) = schema {
293
+ let field_type = schema
294
+ . inner
295
+ . get_field ( key. as_str ( ) )
296
+ . map ( |field| {
297
+ schema. inner . get_field_entry ( field) . field_type ( )
298
+ } )
299
+ . ok ( ) ;
300
+
301
+ if let Some ( field_type) = field_type {
302
+ // A field type was found, so validate it after the values are extracted.
303
+ Some ( field_type)
304
+ } else {
305
+ // The field does not exist in the schema, so skip over it.
306
+ continue ;
307
+ }
308
+ } else {
309
+ // No schema was provided, so do not validate anything.
310
+ None
311
+ } ;
312
+
313
+ let value_list = if let Some ( field_type) = field_type {
314
+ extract_value_single_or_list_for_type (
315
+ key_value. get_item ( 1 ) ?,
316
+ field_type,
317
+ key. as_str ( ) ,
318
+ ) ?
319
+ } else {
320
+ extract_value_single_or_list ( key_value. get_item ( 1 ) ?) ?
321
+ } ;
322
+
323
+ out_field_values. insert ( key, value_list) ;
324
+ }
325
+ }
326
+
327
+ Ok ( ( ) )
328
+ }
329
+ }
330
+
211
331
#[ pymethods]
212
332
impl Document {
213
333
#[ new]
214
334
#[ pyo3( signature = ( * * kwargs) ) ]
215
335
fn new ( kwargs : Option < & PyDict > ) -> PyResult < Self > {
216
336
let mut document = Document :: default ( ) ;
217
337
if let Some ( field_dict) = kwargs {
218
- document. extend ( field_dict) ?;
338
+ document. extend ( field_dict, None ) ?;
219
339
}
220
340
Ok ( document)
221
341
}
222
342
223
- fn extend ( & mut self , py_dict : & PyDict ) -> PyResult < ( ) > {
224
- let mut field_values: BTreeMap < String , Vec < tv:: schema:: Value > > =
225
- BTreeMap :: new ( ) ;
226
- for key_value_any in py_dict. items ( ) {
227
- if let Ok ( key_value) = key_value_any. downcast :: < PyTuple > ( ) {
228
- if key_value. len ( ) != 2 {
229
- continue ;
230
- }
231
- let key: String = key_value. get_item ( 0 ) ?. extract ( ) ?;
232
- let value_list =
233
- extract_value_single_or_list ( key_value. get_item ( 1 ) ?) ?;
234
- field_values. insert ( key, value_list) ;
235
- }
236
- }
237
- self . field_values . extend ( field_values. into_iter ( ) ) ;
238
- Ok ( ( ) )
343
+ fn extend (
344
+ & mut self ,
345
+ py_dict : & PyDict ,
346
+ schema : Option < & Schema > ,
347
+ ) -> PyResult < ( ) > {
348
+ Document :: extract_py_values_from_dict (
349
+ py_dict,
350
+ schema,
351
+ & mut self . field_values ,
352
+ )
239
353
}
240
354
241
355
#[ staticmethod]
242
- fn from_dict ( py_dict : & PyDict ) -> PyResult < Document > {
356
+ fn from_dict (
357
+ py_dict : & PyDict ,
358
+ schema : Option < & Schema > ,
359
+ ) -> PyResult < Document > {
243
360
let mut field_values: BTreeMap < String , Vec < tv:: schema:: Value > > =
244
361
BTreeMap :: new ( ) ;
245
- for key_value_any in py_dict. items ( ) {
246
- if let Ok ( key_value) = key_value_any. downcast :: < PyTuple > ( ) {
247
- if key_value. len ( ) != 2 {
248
- continue ;
249
- }
250
- let key: String = key_value. get_item ( 0 ) ?. extract ( ) ?;
251
- let value_list =
252
- extract_value_single_or_list ( key_value. get_item ( 1 ) ?) ?;
253
- field_values. insert ( key, value_list) ;
254
- }
255
- }
362
+ Document :: extract_py_values_from_dict (
363
+ py_dict,
364
+ schema,
365
+ & mut field_values,
366
+ ) ?;
256
367
Ok ( Document { field_values } )
257
368
}
258
369
0 commit comments