3
3
use crate :: {
4
4
attributes:: {
5
5
self , take_attributes, take_pyo3_options, CrateAttribute , ModuleAttribute , NameAttribute ,
6
+ SubmoduleAttribute ,
6
7
} ,
7
8
get_doc,
8
9
pyclass:: PyClassPyO3Option ,
@@ -27,6 +28,7 @@ pub struct PyModuleOptions {
27
28
krate : Option < CrateAttribute > ,
28
29
name : Option < syn:: Ident > ,
29
30
module : Option < ModuleAttribute > ,
31
+ is_submodule : bool ,
30
32
}
31
33
32
34
impl PyModuleOptions {
@@ -38,6 +40,7 @@ impl PyModuleOptions {
38
40
PyModulePyO3Option :: Name ( name) => options. set_name ( name. value . 0 ) ?,
39
41
PyModulePyO3Option :: Crate ( path) => options. set_crate ( path) ?,
40
42
PyModulePyO3Option :: Module ( module) => options. set_module ( module) ?,
43
+ PyModulePyO3Option :: Submodule ( submod) => options. set_submodule ( submod) ?,
41
44
}
42
45
}
43
46
@@ -73,9 +76,22 @@ impl PyModuleOptions {
73
76
self . module = Some ( name) ;
74
77
Ok ( ( ) )
75
78
}
79
+
80
+ fn set_submodule ( & mut self , submod : SubmoduleAttribute ) -> Result < ( ) > {
81
+ ensure_spanned ! (
82
+ !self . is_submodule,
83
+ submod. span( ) => "`submodule` may only be specified once"
84
+ ) ;
85
+
86
+ self . is_submodule = true ;
87
+ Ok ( ( ) )
88
+ }
76
89
}
77
90
78
- pub fn pymodule_module_impl ( mut module : syn:: ItemMod ) -> Result < TokenStream > {
91
+ pub fn pymodule_module_impl (
92
+ mut module : syn:: ItemMod ,
93
+ mut is_submodule : bool ,
94
+ ) -> Result < TokenStream > {
79
95
let syn:: ItemMod {
80
96
attrs,
81
97
vis,
@@ -100,6 +116,7 @@ pub fn pymodule_module_impl(mut module: syn::ItemMod) -> Result<TokenStream> {
100
116
} else {
101
117
name. to_string ( )
102
118
} ;
119
+ is_submodule = is_submodule || options. is_submodule ;
103
120
104
121
let mut module_items = Vec :: new ( ) ;
105
122
let mut module_items_cfg_attrs = Vec :: new ( ) ;
@@ -286,7 +303,7 @@ pub fn pymodule_module_impl(mut module: syn::ItemMod) -> Result<TokenStream> {
286
303
}
287
304
}
288
305
289
- let initialization = module_initialization ( & name, ctx) ;
306
+ let initialization = module_initialization ( & name, ctx, is_submodule ) ;
290
307
Ok ( quote ! (
291
308
#( #attrs) *
292
309
#vis mod #ident {
@@ -335,7 +352,7 @@ pub fn pymodule_function_impl(mut function: syn::ItemFn) -> Result<TokenStream>
335
352
let vis = & function. vis ;
336
353
let doc = get_doc ( & function. attrs , None , ctx) ;
337
354
338
- let initialization = module_initialization ( & name, ctx) ;
355
+ let initialization = module_initialization ( & name, ctx, false ) ;
339
356
340
357
// Module function called with optional Python<'_> marker as first arg, followed by the module.
341
358
let mut module_args = Vec :: new ( ) ;
@@ -400,28 +417,34 @@ pub fn pymodule_function_impl(mut function: syn::ItemFn) -> Result<TokenStream>
400
417
} )
401
418
}
402
419
403
- fn module_initialization ( name : & syn:: Ident , ctx : & Ctx ) -> TokenStream {
420
+ fn module_initialization ( name : & syn:: Ident , ctx : & Ctx , is_submodule : bool ) -> TokenStream {
404
421
let Ctx { pyo3_path, .. } = ctx;
405
422
let pyinit_symbol = format ! ( "PyInit_{}" , name) ;
406
423
let name = name. to_string ( ) ;
407
424
let pyo3_name = LitCStr :: new ( CString :: new ( name) . unwrap ( ) , Span :: call_site ( ) , ctx) ;
408
425
409
- quote ! {
426
+ let mut base = quote ! {
410
427
#[ doc( hidden) ]
411
428
pub const __PYO3_NAME: & ' static :: std:: ffi:: CStr = #pyo3_name;
412
429
413
430
pub ( super ) struct MakeDef ;
414
431
#[ doc( hidden) ]
415
432
pub static _PYO3_DEF: #pyo3_path:: impl_:: pymodule:: ModuleDef = MakeDef :: make_def( ) ;
416
-
417
- /// This autogenerated function is called by the python interpreter when importing
418
- /// the module.
419
- #[ doc( hidden) ]
420
- #[ export_name = #pyinit_symbol]
421
- pub unsafe extern "C" fn __pyo3_init( ) -> * mut #pyo3_path:: ffi:: PyObject {
422
- #pyo3_path:: impl_:: trampoline:: module_init( |py| _PYO3_DEF. make_module( py) )
423
- }
433
+ } ;
434
+ if !is_submodule {
435
+ base = quote ! {
436
+ #base
437
+
438
+ /// This autogenerated function is called by the python interpreter when importing
439
+ /// the module.
440
+ #[ doc( hidden) ]
441
+ #[ export_name = #pyinit_symbol]
442
+ pub unsafe extern "C" fn __pyo3_init( ) -> * mut #pyo3_path:: ffi:: PyObject {
443
+ #pyo3_path:: impl_:: trampoline:: module_init( |py| _PYO3_DEF. make_module( py) )
444
+ }
445
+ } ;
424
446
}
447
+ base
425
448
}
426
449
427
450
/// Finds and takes care of the #[pyfn(...)] in `#[pymodule]`
@@ -561,6 +584,7 @@ fn has_pyo3_module_declared<T: Parse>(
561
584
}
562
585
563
586
enum PyModulePyO3Option {
587
+ Submodule ( SubmoduleAttribute ) ,
564
588
Crate ( CrateAttribute ) ,
565
589
Name ( NameAttribute ) ,
566
590
Module ( ModuleAttribute ) ,
@@ -575,6 +599,8 @@ impl Parse for PyModulePyO3Option {
575
599
input. parse ( ) . map ( PyModulePyO3Option :: Crate )
576
600
} else if lookahead. peek ( attributes:: kw:: module) {
577
601
input. parse ( ) . map ( PyModulePyO3Option :: Module )
602
+ } else if lookahead. peek ( attributes:: kw:: submodule) {
603
+ input. parse ( ) . map ( PyModulePyO3Option :: Submodule )
578
604
} else {
579
605
Err ( lookahead. error ( ) )
580
606
}
0 commit comments