@@ -13,7 +13,7 @@ use chrono::{offset::TimeZone, Utc};
13
13
14
14
use tantivy as tv;
15
15
16
- use crate :: { facet:: Facet , to_pyerr} ;
16
+ use crate :: { facet:: Facet , schema :: Schema , to_pyerr} ;
17
17
use serde_json:: Value as JsonValue ;
18
18
use std:: {
19
19
collections:: { BTreeMap , HashMap } ,
@@ -200,6 +200,71 @@ pub(crate) fn extract_value(any: &PyAny) -> PyResult<Value> {
200
200
Err ( to_pyerr ( format ! ( "Value unsupported {any:?}" ) ) )
201
201
}
202
202
203
+ pub ( crate ) fn extract_value_for_type (
204
+ any : & PyAny ,
205
+ tv_type : tv:: schema:: Type ,
206
+ field_name : & str ,
207
+ ) -> PyResult < Value > {
208
+ // Helper function to create `PyErr`s returned by this function.
209
+ fn to_pyerr_for_type < ' a , E : std:: error:: Error > (
210
+ type_name : & ' a str ,
211
+ field_name : & ' a str ,
212
+ any : & ' a PyAny ,
213
+ ) -> impl Fn ( E ) -> PyErr + ' a {
214
+ move |_| {
215
+ to_pyerr ( format ! (
216
+ "Expected {} type for field {}, got {:?}" ,
217
+ type_name, field_name, any
218
+ ) )
219
+ }
220
+ }
221
+
222
+ let value = match tv_type {
223
+ tv:: schema:: Type :: Str => Value :: Str (
224
+ any. extract :: < String > ( )
225
+ . map_err ( to_pyerr_for_type ( "Str" , field_name, any) ) ?,
226
+ ) ,
227
+ tv:: schema:: Type :: U64 => Value :: U64 (
228
+ any. extract :: < u64 > ( )
229
+ . map_err ( to_pyerr_for_type ( "U64" , field_name, any) ) ?,
230
+ ) ,
231
+ tv:: schema:: Type :: I64 => Value :: I64 (
232
+ any. extract :: < i64 > ( )
233
+ . map_err ( to_pyerr_for_type ( "I64" , field_name, any) ) ?,
234
+ ) ,
235
+ tv:: schema:: Type :: F64 => Value :: F64 (
236
+ any. extract :: < f64 > ( )
237
+ . map_err ( to_pyerr_for_type ( "F64" , field_name, any) ) ?,
238
+ ) ,
239
+ tv:: schema:: Type :: Date => {
240
+ let py_datetime = any
241
+ . downcast :: < PyDateTime > ( )
242
+ . map_err ( to_pyerr_for_type ( "DateTime" , field_name, any) ) ?;
243
+ let datetime = Utc
244
+ . with_ymd_and_hms (
245
+ py_datetime. get_year ( ) ,
246
+ py_datetime. get_month ( ) . into ( ) ,
247
+ py_datetime. get_day ( ) . into ( ) ,
248
+ py_datetime. get_hour ( ) . into ( ) ,
249
+ py_datetime. get_minute ( ) . into ( ) ,
250
+ py_datetime. get_second ( ) . into ( ) ,
251
+ )
252
+ . single ( )
253
+ . unwrap ( ) ;
254
+
255
+ Value :: Date ( tv:: DateTime :: from_timestamp_secs ( datetime. timestamp ( ) ) )
256
+ }
257
+ tv:: schema:: Type :: Facet => Value :: Facet (
258
+ any. extract :: < Facet > ( )
259
+ . map_err ( to_pyerr_for_type ( "Facet" , field_name, any) ) ?
260
+ . inner ,
261
+ ) ,
262
+ _ => return Err ( to_pyerr ( format ! ( "Value unsupported {:?}" , any) ) ) ,
263
+ } ;
264
+
265
+ Ok ( value)
266
+ }
267
+
203
268
fn extract_value_single_or_list ( any : & PyAny ) -> PyResult < Vec < Value > > {
204
269
if let Ok ( values) = any. downcast :: < PyList > ( ) {
205
270
values. iter ( ) . map ( extract_value) . collect ( )
@@ -208,51 +273,135 @@ fn extract_value_single_or_list(any: &PyAny) -> PyResult<Vec<Value>> {
208
273
}
209
274
}
210
275
211
- #[ pymethods]
212
- impl Document {
213
- #[ new]
214
- #[ pyo3( signature = ( * * kwargs) ) ]
215
- fn new ( kwargs : Option < & PyDict > ) -> PyResult < Self > {
216
- let mut document = Document :: default ( ) ;
217
- if let Some ( field_dict) = kwargs {
218
- document. extend ( field_dict) ?;
276
+ fn extract_value_single_or_list_for_type (
277
+ any : & PyAny ,
278
+ field_type : & tv:: schema:: FieldType ,
279
+ field_name : & str ,
280
+ ) -> PyResult < Vec < Value > > {
281
+ // Check if a numeric fast field supports multivalues.
282
+ let allow_multivalues = match field_type {
283
+ tv:: schema:: FieldType :: U64 ( opts)
284
+ | tv:: schema:: FieldType :: I64 ( opts)
285
+ | tv:: schema:: FieldType :: F64 ( opts) => opts
286
+ . get_fastfield_cardinality ( )
287
+ . map ( |card| card == tv:: schema:: Cardinality :: MultiValues )
288
+ . unwrap_or ( true ) ,
289
+ _ => true ,
290
+ } ;
291
+
292
+ if let Ok ( values) = any. downcast :: < PyList > ( ) {
293
+ let result = values
294
+ . iter ( )
295
+ . map ( |any| {
296
+ extract_value_for_type ( any, field_type. value_type ( ) , field_name)
297
+ } )
298
+ . collect :: < PyResult < Vec < Value > > > ( ) ?;
299
+
300
+ if !allow_multivalues && result. len ( ) > 1 {
301
+ return Err ( to_pyerr ( format ! (
302
+ "Field '{}' does not support multivalues" ,
303
+ field_name
304
+ ) ) ) ;
305
+ } else {
306
+ Ok ( result)
219
307
}
220
- Ok ( document)
308
+ } else {
309
+ Ok ( vec ! [ extract_value_for_type(
310
+ any,
311
+ field_type. value_type( ) ,
312
+ field_name,
313
+ ) ?] )
221
314
}
315
+ }
316
+
317
+ impl Document {
318
+ fn extract_py_values_from_dict (
319
+ py_dict : & PyDict ,
320
+ schema : Option < & Schema > ,
321
+ out_field_values : & mut BTreeMap < String , Vec < tv:: schema:: Value > > ,
322
+ ) -> PyResult < ( ) > {
323
+ // TODO: Reserve when https://github.com/rust-lang/rust/issues/72631 is stable.
324
+ // out_field_values.reserve(py_dict.len());
222
325
223
- fn extend ( & mut self , py_dict : & PyDict ) -> PyResult < ( ) > {
224
- let mut field_values: BTreeMap < String , Vec < tv:: schema:: Value > > =
225
- BTreeMap :: new ( ) ;
226
326
for key_value_any in py_dict. items ( ) {
227
327
if let Ok ( key_value) = key_value_any. downcast :: < PyTuple > ( ) {
228
328
if key_value. len ( ) != 2 {
229
329
continue ;
230
330
}
231
- let key: String = key_value. get_item ( 0 ) ?. extract ( ) ?;
232
- let value_list =
233
- extract_value_single_or_list ( key_value. get_item ( 1 ) ?) ?;
234
- field_values. insert ( key, value_list) ;
331
+ let key = key_value. get_item ( 0 ) ?. extract :: < String > ( ) ?;
332
+
333
+ let field_type = if let Some ( schema) = schema {
334
+ let field_type =
335
+ schema. inner . get_field ( key. as_str ( ) ) . map ( |field| {
336
+ schema. inner . get_field_entry ( field) . field_type ( )
337
+ } ) ;
338
+
339
+ if let Some ( field_type) = field_type {
340
+ // A field type was found, so validate it after the values are extracted.
341
+ Some ( field_type)
342
+ } else {
343
+ // The field does not exist in the schema, so skip over it.
344
+ continue ;
345
+ }
346
+ } else {
347
+ // No schema was provided, so do not validate anything.
348
+ None
349
+ } ;
350
+
351
+ let value_list = if let Some ( field_type) = field_type {
352
+ extract_value_single_or_list_for_type (
353
+ key_value. get_item ( 1 ) ?,
354
+ field_type,
355
+ key. as_str ( ) ,
356
+ ) ?
357
+ } else {
358
+ extract_value_single_or_list ( key_value. get_item ( 1 ) ?) ?
359
+ } ;
360
+
361
+ out_field_values. insert ( key, value_list) ;
235
362
}
236
363
}
237
- self . field_values . extend ( field_values . into_iter ( ) ) ;
364
+
238
365
Ok ( ( ) )
239
366
}
367
+ }
368
+
369
+ #[ pymethods]
370
+ impl Document {
371
+ #[ new]
372
+ #[ pyo3( signature = ( * * kwargs) ) ]
373
+ fn new ( kwargs : Option < & PyDict > ) -> PyResult < Self > {
374
+ let mut document = Document :: default ( ) ;
375
+ if let Some ( field_dict) = kwargs {
376
+ document. extend ( field_dict, None ) ?;
377
+ }
378
+ Ok ( document)
379
+ }
380
+
381
+ fn extend (
382
+ & mut self ,
383
+ py_dict : & PyDict ,
384
+ schema : Option < & Schema > ,
385
+ ) -> PyResult < ( ) > {
386
+ Document :: extract_py_values_from_dict (
387
+ py_dict,
388
+ schema,
389
+ & mut self . field_values ,
390
+ )
391
+ }
240
392
241
393
#[ staticmethod]
242
- fn from_dict ( py_dict : & PyDict ) -> PyResult < Document > {
394
+ fn from_dict (
395
+ py_dict : & PyDict ,
396
+ schema : Option < & Schema > ,
397
+ ) -> PyResult < Document > {
243
398
let mut field_values: BTreeMap < String , Vec < tv:: schema:: Value > > =
244
399
BTreeMap :: new ( ) ;
245
- for key_value_any in py_dict. items ( ) {
246
- if let Ok ( key_value) = key_value_any. downcast :: < PyTuple > ( ) {
247
- if key_value. len ( ) != 2 {
248
- continue ;
249
- }
250
- let key: String = key_value. get_item ( 0 ) ?. extract ( ) ?;
251
- let value_list =
252
- extract_value_single_or_list ( key_value. get_item ( 1 ) ?) ?;
253
- field_values. insert ( key, value_list) ;
254
- }
255
- }
400
+ Document :: extract_py_values_from_dict (
401
+ py_dict,
402
+ schema,
403
+ & mut field_values,
404
+ ) ?;
256
405
Ok ( Document { field_values } )
257
406
}
258
407
0 commit comments