Skip to content

Commit 569df7b

Browse files
Refactor nnUNet documentation and examples for clarity; update fold parameter type in tests
1 parent 43c694b commit 569df7b

File tree

2 files changed

+43
-39
lines changed

2 files changed

+43
-39
lines changed

monai/bundle/nnunet.py

Lines changed: 42 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -49,25 +49,25 @@ def get_nnunet_trainer(
4949
The returned nnUNet trainer can be used to initialize the SupervisedTrainer for training, including the network,
5050
optimizer, loss function, DataLoader, etc.
5151
52-
```python
53-
from monai.apps import SupervisedTrainer
54-
from monai.bundle.nnunet import get_nnunet_trainer
55-
56-
dataset_name_or_id = 'Task101_PROSTATE'
57-
fold = 0
58-
configuration = '3d_fullres'
59-
nnunet_trainer = get_nnunet_trainer(dataset_name_or_id, configuration, fold)
60-
61-
trainer = SupervisedTrainer(
62-
device=nnunet_trainer.device,
63-
max_epochs=nnunet_trainer.num_epochs,
64-
train_data_loader=nnunet_trainer.dataloader_train,
65-
network=nnunet_trainer.network,
66-
optimizer=nnunet_trainer.optimizer,
67-
loss_function=nnunet_trainer.loss_function,
68-
epoch_length=nnunet_trainer.num_iterations_per_epoch,
69-
70-
```
52+
Example::
53+
54+
from monai.apps import SupervisedTrainer
55+
from monai.bundle.nnunet import get_nnunet_trainer
56+
57+
dataset_name_or_id = 'Task101_PROSTATE'
58+
fold = 0
59+
configuration = '3d_fullres'
60+
nnunet_trainer = get_nnunet_trainer(dataset_name_or_id, configuration, fold)
61+
62+
trainer = SupervisedTrainer(
63+
device=nnunet_trainer.device,
64+
max_epochs=nnunet_trainer.num_epochs,
65+
train_data_loader=nnunet_trainer.dataloader_train,
66+
network=nnunet_trainer.network,
67+
optimizer=nnunet_trainer.optimizer,
68+
loss_function=nnunet_trainer.loss_function,
69+
epoch_length=nnunet_trainer.num_iterations_per_epoch,
70+
)
7171
7272
Parameters
7373
----------
@@ -162,16 +162,19 @@ class ModelnnUNetWrapper(torch.nn.Module):
162162
The folder path where the model and related files are stored.
163163
model_name : str, optional
164164
The name of the model file, by default "model.pt".
165+
165166
Attributes
166167
----------
167-
predictor : object
168-
The predictor object used for inference.
168+
predictor : nnUNetPredictor
169+
The nnUNet predictor object used for inference.
169170
network_weights : torch.nn.Module
170171
The network weights of the model.
172+
171173
Methods
172174
-------
173175
forward(x)
174176
Perform forward pass and prediction on the input data.
177+
175178
Notes
176179
-----
177180
This class integrates nnUNet model with MONAI framework by loading necessary configurations,
@@ -183,7 +186,7 @@ def __init__(self, predictor, model_folder, model_name="model.pt"):
183186
self.predictor = predictor
184187

185188
model_training_output_dir = model_folder
186-
use_folds = "0"
189+
use_folds = ["0"]
187190

188191
from nnunetv2.utilities.plans_handling.plans_handler import PlansManager
189192

@@ -290,27 +293,28 @@ def forward(self, x):
290293

291294
def get_nnunet_monai_predictor(model_folder, model_name="model.pt"):
292295
"""
293-
Initializes and returns a nnUNetMONAIModelWrapper with a nnUNetPredictor.
296+
Initializes and returns a `nnUNetMONAIModelWrapper` containing the corresponding `nnUNetPredictor`.
294297
The model folder should contain the following files, created during training:
295-
- dataset.json: from the nnUNet results folder.
296-
- plans.json: from the nnUNet results folder.
297-
- nnunet_checkpoint.pth: The nnUNet checkpoint file, containing the nnUNet training configuration
298-
(`init_kwargs`, `trainer_name`, `inference_allowed_mirroring_axes`).
299-
- model.pt: The checkpoint file containing the model weights.
300-
298+
299+
- dataset.json: from the nnUNet results folder
300+
- plans.json: from the nnUNet results folder
301+
- nnunet_checkpoint.pth: The nnUNet checkpoint file, containing the nnUNet training configuration (`init_kwargs`, `trainer_name`, `inference_allowed_mirroring_axes`)
302+
- model.pt: The checkpoint file containing the model weights.
303+
301304
The returned wrapper object can be used for inference with MONAI framework:
302-
```python
303-
from monai.bundle.nnunet import get_nnunet_monai_predictor
305+
306+
Example::
307+
308+
from monai.bundle.nnunet import get_nnunet_monai_predictor
304309
305-
model_folder = 'path/to/monai_bundle/model'
306-
model_name = 'model.pt'
307-
wrapper = get_nnunet_monai_predictor(model_folder, model_name)
310+
model_folder = 'path/to/monai_bundle/model'
311+
model_name = 'model.pt'
312+
wrapper = get_nnunet_monai_predictor(model_folder, model_name)
308313
309-
# Perform inference
310-
input_data = ...
311-
output = wrapper(input_data)
314+
# Perform inference
315+
input_data = ...
316+
output = wrapper(input_data)
312317
313-
```
314318
315319
Parameters
316320
----------

tests/test_integration_nnunet_bundle.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def test_nnunet_bundle(self) -> None:
107107
print("Loss Function: ", nnunet_trainer.loss)
108108
print("LR Scheduler: ", nnunet_trainer.lr_scheduler)
109109
print("Device: ", nnunet_trainer.device)
110-
runner.train_single_model("3d_fullres", fold="0")
110+
runner.train_single_model("3d_fullres", fold=0)
111111

112112
nnunet_config = {"dataset_name_or_id": "001", "nnunet_trainer": "nnUNetTrainer_1epoch"}
113113
self.bundle_root = os.path.join("bundle_root")

0 commit comments

Comments
 (0)