Skip to content

refs #4286 -- allow setting submodule on declarative pymodules #4301

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion guide/src/module.md
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,8 @@ The `#[pymodule]` macro automatically sets the `module` attribute of the `#[pycl
For nested modules, the name of the parent module is automatically added.
In the following example, the `Unit` class will have for `module` `my_extension.submodule` because it is properly nested
but the `Ext` class will have for `module` the default `builtins` because it not nested.

You can provide the `submodule` argument to `pymodule()` for modules that are not top-level modules.
```rust
# mod declarative_module_module_attr_test {
use pyo3::prelude::*;
Expand All @@ -168,7 +170,7 @@ mod my_extension {
#[pymodule_export]
use super::Ext;

#[pymodule]
#[pymodule(submodule)]
mod submodule {
use super::*;
// This is a submodule
Expand Down
1 change: 1 addition & 0 deletions newsfragments/4301.added.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
allow setting `submodule` on declarative `#[pymodule]`s
2 changes: 2 additions & 0 deletions pyo3-macros-backend/src/attributes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ pub mod kw {
syn::custom_keyword!(set_all);
syn::custom_keyword!(signature);
syn::custom_keyword!(subclass);
syn::custom_keyword!(submodule);
syn::custom_keyword!(text_signature);
syn::custom_keyword!(transparent);
syn::custom_keyword!(unsendable);
Expand Down Expand Up @@ -178,6 +179,7 @@ pub type ModuleAttribute = KeywordAttribute<kw::module, LitStr>;
pub type NameAttribute = KeywordAttribute<kw::name, NameLitStr>;
pub type RenameAllAttribute = KeywordAttribute<kw::rename_all, RenamingRuleLitStr>;
pub type TextSignatureAttribute = KeywordAttribute<kw::text_signature, TextSignatureAttributeValue>;
pub type SubmoduleAttribute = kw::submodule;

impl<K: Parse + std::fmt::Debug, V: Parse> Parse for KeywordAttribute<K, V> {
fn parse(input: ParseStream<'_>) -> Result<Self> {
Expand Down
55 changes: 42 additions & 13 deletions pyo3-macros-backend/src/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
use crate::{
attributes::{
self, take_attributes, take_pyo3_options, CrateAttribute, ModuleAttribute, NameAttribute,
SubmoduleAttribute,
},
get_doc,
pyclass::PyClassPyO3Option,
Expand All @@ -27,6 +28,7 @@ pub struct PyModuleOptions {
krate: Option<CrateAttribute>,
name: Option<syn::Ident>,
module: Option<ModuleAttribute>,
is_submodule: bool,
}

impl PyModuleOptions {
Expand All @@ -38,6 +40,7 @@ impl PyModuleOptions {
PyModulePyO3Option::Name(name) => options.set_name(name.value.0)?,
PyModulePyO3Option::Crate(path) => options.set_crate(path)?,
PyModulePyO3Option::Module(module) => options.set_module(module)?,
PyModulePyO3Option::Submodule(submod) => options.set_submodule(submod)?,
}
}

Expand Down Expand Up @@ -73,9 +76,22 @@ impl PyModuleOptions {
self.module = Some(name);
Ok(())
}

fn set_submodule(&mut self, submod: SubmoduleAttribute) -> Result<()> {
ensure_spanned!(
!self.is_submodule,
submod.span() => "`submodule` may only be specified once"
);

self.is_submodule = true;
Ok(())
}
}

pub fn pymodule_module_impl(mut module: syn::ItemMod) -> Result<TokenStream> {
pub fn pymodule_module_impl(
mut module: syn::ItemMod,
mut is_submodule: bool,
) -> Result<TokenStream> {
let syn::ItemMod {
attrs,
vis,
Expand All @@ -100,6 +116,7 @@ pub fn pymodule_module_impl(mut module: syn::ItemMod) -> Result<TokenStream> {
} else {
name.to_string()
};
is_submodule = is_submodule || options.is_submodule;

let mut module_items = Vec::new();
let mut module_items_cfg_attrs = Vec::new();
Expand Down Expand Up @@ -297,7 +314,7 @@ pub fn pymodule_module_impl(mut module: syn::ItemMod) -> Result<TokenStream> {
)
}
}};
let initialization = module_initialization(&name, ctx, module_def);
let initialization = module_initialization(&name, ctx, module_def, is_submodule);
Ok(quote!(
#(#attrs)*
#vis mod #ident {
Expand Down Expand Up @@ -331,7 +348,7 @@ pub fn pymodule_function_impl(mut function: syn::ItemFn) -> Result<TokenStream>
let vis = &function.vis;
let doc = get_doc(&function.attrs, None, ctx);

let initialization = module_initialization(&name, ctx, quote! { MakeDef::make_def() });
let initialization = module_initialization(&name, ctx, quote! { MakeDef::make_def() }, false);

// Module function called with optional Python<'_> marker as first arg, followed by the module.
let mut module_args = Vec::new();
Expand Down Expand Up @@ -396,28 +413,37 @@ pub fn pymodule_function_impl(mut function: syn::ItemFn) -> Result<TokenStream>
})
}

fn module_initialization(name: &syn::Ident, ctx: &Ctx, module_def: TokenStream) -> TokenStream {
fn module_initialization(
name: &syn::Ident,
ctx: &Ctx,
module_def: TokenStream,
is_submodule: bool,
) -> TokenStream {
let Ctx { pyo3_path, .. } = ctx;
let pyinit_symbol = format!("PyInit_{}", name);
let name = name.to_string();
let pyo3_name = LitCStr::new(CString::new(name).unwrap(), Span::call_site(), ctx);

quote! {
let mut result = quote! {
#[doc(hidden)]
pub const __PYO3_NAME: &'static ::std::ffi::CStr = #pyo3_name;

pub(super) struct MakeDef;
#[doc(hidden)]
pub static _PYO3_DEF: #pyo3_path::impl_::pymodule::ModuleDef = #module_def;

/// This autogenerated function is called by the python interpreter when importing
/// the module.
#[doc(hidden)]
#[export_name = #pyinit_symbol]
pub unsafe extern "C" fn __pyo3_init() -> *mut #pyo3_path::ffi::PyObject {
#pyo3_path::impl_::trampoline::module_init(|py| _PYO3_DEF.make_module(py))
}
};
if !is_submodule {
result.extend(quote! {
/// This autogenerated function is called by the python interpreter when importing
/// the module.
#[doc(hidden)]
#[export_name = #pyinit_symbol]
pub unsafe extern "C" fn __pyo3_init() -> *mut #pyo3_path::ffi::PyObject {
#pyo3_path::impl_::trampoline::module_init(|py| _PYO3_DEF.make_module(py))
}
});
}
result
}

/// Finds and takes care of the #[pyfn(...)] in `#[pymodule]`
Expand Down Expand Up @@ -557,6 +583,7 @@ fn has_pyo3_module_declared<T: Parse>(
}

enum PyModulePyO3Option {
Submodule(SubmoduleAttribute),
Crate(CrateAttribute),
Name(NameAttribute),
Module(ModuleAttribute),
Expand All @@ -571,6 +598,8 @@ impl Parse for PyModulePyO3Option {
input.parse().map(PyModulePyO3Option::Crate)
} else if lookahead.peek(attributes::kw::module) {
input.parse().map(PyModulePyO3Option::Module)
} else if lookahead.peek(attributes::kw::submodule) {
input.parse().map(PyModulePyO3Option::Submodule)
} else {
Err(lookahead.error())
}
Expand Down
24 changes: 20 additions & 4 deletions pyo3-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use proc_macro2::{Span, TokenStream as TokenStream2};
use pyo3_macros_backend::{
build_derive_from_pyobject, build_py_class, build_py_enum, build_py_function, build_py_methods,
pymodule_function_impl, pymodule_module_impl, PyClassArgs, PyClassMethodsType,
Expand Down Expand Up @@ -35,10 +35,26 @@ use syn::{parse::Nothing, parse_macro_input, Item};
/// [1]: https://pyo3.rs/latest/module.html
#[proc_macro_attribute]
pub fn pymodule(args: TokenStream, input: TokenStream) -> TokenStream {
parse_macro_input!(args as Nothing);
match parse_macro_input!(input as Item) {
Item::Mod(module) => pymodule_module_impl(module),
Item::Fn(function) => pymodule_function_impl(function),
Item::Mod(module) => {
let is_submodule = match parse_macro_input!(args as Option<syn::Ident>) {
Some(i) if i == "submodule" => true,
Some(_) => {
return syn::Error::new(
Span::call_site(),
"#[pymodule] only accepts submodule as an argument",
)
.into_compile_error()
.into();
}
None => false,
};
pymodule_module_impl(module, is_submodule)
}
Item::Fn(function) => {
parse_macro_input!(args as Nothing);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it seems reasonable to me for function modules to also support submodule?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure it's relevant for function modules -- there's never been a need to use #[pymodule] for submodules with functions.

pymodule_function_impl(function)
}
unsupported => Err(syn::Error::new_spanned(
unsupported,
"#[pymodule] only supports modules and functions.",
Expand Down
10 changes: 9 additions & 1 deletion tests/test_declarative_module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ create_exception!(
"Some description."
);

#[pymodule]
#[pyo3(submodule)]
mod external_submodule {}

/// A module written using declarative syntax.
#[pymodule]
mod declarative_module {
Expand All @@ -70,6 +74,9 @@ mod declarative_module {
#[pymodule_export]
use super::some_module::SomeException;

#[pymodule_export]
use super::external_submodule;

#[pymodule]
mod inner {
use super::*;
Expand Down Expand Up @@ -108,7 +115,7 @@ mod declarative_module {
}
}

#[pymodule]
#[pymodule(submodule)]
#[pyo3(module = "custom_root")]
mod inner_custom_root {
use super::*;
Expand Down Expand Up @@ -174,6 +181,7 @@ fn test_declarative_module() {
py_assert!(py, m, "hasattr(m, 'LocatedClass')");
py_assert!(py, m, "isinstance(m.inner.Struct(), m.inner.Struct)");
py_assert!(py, m, "isinstance(m.inner.Enum.A, m.inner.Enum)");
py_assert!(py, m, "hasattr(m, 'external_submodule')")
})
}

Expand Down
Loading